diff --git a/BUILD b/BUILD index 385e328..65b049c 100644 --- a/BUILD +++ b/BUILD @@ -611,11 +611,107 @@ cc_library( alwayslink = 1, ) +gentbl_cc_library( + name = "DiscRalPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=RAL", + ], + "include/mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.td", + td_includes = [ + "external/mlir-hlo/include", + "include", + ], + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "disc_ral_ops_inc_gen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.td", + td_includes = [ + "external/mlir-hlo/include", + "include", + ], + deps = [":hlo_ops_td_files"], +) + +cc_library( + name = "disc_ral", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h.inc", + "lib/Dialect/mhlo/IR/disc_ral_ops.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h", + ], + includes = ["include"], + deps = [ + ":disc_ral_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:CopyOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", + ], + alwayslink = 1, +) + +cc_library( + name = "ral_inject_execution_context", + srcs = ["lib/Dialect/mhlo/transforms/ral_inject_execution_context.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], + deps = [ + ":disc_ral", + ":pass_details", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "hlo_dialect_registration", srcs = ["lib/Dialect/mhlo/IR/init.cc"], hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/register.h"], deps = [ + ":disc_ral", ":hlo", ":lhlo", ":lhlo_gpu", @@ -1159,6 +1255,7 @@ cc_library( "//visibility:private", # This target is a private detail of pass implementations ], deps = [ + ":DiscRalPassIncGen", ":MhloPassIncGen", "@llvm-project//mlir:Pass", ], @@ -1202,6 +1299,7 @@ cc_library( "include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h", ], deps = [ + ":DiscRalPassIncGen", ":LmhloPassIncGen", ":MhloPassIncGen", ":broadcast_propagation", @@ -1221,6 +1319,7 @@ cc_library( ":mhlo_control_flow_to_scf", ":mhlo_fusion", ":mhlo_to_mhlo_lowering_patterns", + ":ral_inject_execution_context", ":rank_specialization", ":sink_constants_to_control_flow", ":test_passes", @@ -1236,6 +1335,7 @@ cc_binary( ], deps = [ ":all_passes", + ":disc_ral", ":hlo", ":lhlo", ":lhlo_gpu", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 8b50b58..8831706 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -55,3 +55,13 @@ add_mlir_hlo_dialect_separate_files(lhlo_ops NO) add_mlir_hlo_dialect_separate_files(lhlo_gpu_ops YES) add_mlir_interface(infer_fusibility_op_interface) + +function(add_disc_ral_dialect dialect) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cc.inc -gen-op-defs) + add_public_tablegen_target(MLIR${dialect}IncGen) + add_dependencies(mlir-headers MLIR${dialect}IncGen) +endfunction() + +add_disc_ral_dialect(disc_ral_ops) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h new file mode 100644 index 0000000..a23d02e --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h @@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the disc_ral dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_DISC_RAL_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_DISC_RAL_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +class OpBuilder; + +namespace disc_ral { + +class RalDialect : public Dialect { + public: + explicit RalDialect(MLIRContext* context); + static StringRef getDialectNamespace() { return "disc_ral"; } + + // Parses a type registered to this dialect. + Type parseType(DialectAsmParser& parser) const override; + + // Prints a type registered to this dialect. + void printType(Type type, DialectAsmPrinter& os) const override; +}; + +class RalExecutionContextType + : public Type::TypeBase { + public: + using Base::Base; +}; + +} // end namespace disc_ral +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_DISC_RAL_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.td new file mode 100644 index 0000000..d544a90 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.td @@ -0,0 +1,95 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for DISC RAL ops. + +#ifndef DISC_RAL_OPS +#define DISC_RAL_OPS + +include "mlir/IR/OpBase.td" + +def RAL_Dialect : Dialect { + let name = "disc_ral"; + let cppNamespace = "::mlir::disc_ral"; +} + +class RAL_Op traits> : + Op { + let verifier = [{ return Verify(*this); }]; +} + +def RAL_RalExecutionContextType : DialectType()">, + "context">, + BuildableType<"$_builder.getType<::mlir::disc_ral::RalExecutionContextType>()"> { + let description = [{ + RalExecutionContextType corresponds to C++ class RalExecution defined in + disc. + }]; +} + +def RAL_RecvInputOp: RAL_Op<"recv_input", []> { + let summary = "receive input operator"; + let description = [{ + Fetch the specified input. + + In DISC, inputs are not passed as parameter directly. Instead, users need to + explicitly fetch the individual input using this op. + }]; + let arguments = (ins + RAL_RalExecutionContextType:$ctx, + Index:$input_idx + ); + let results = (outs MemRefOf<[AnyType]>); +} + +def RAL_SendOutputOp: RAL_Op<"send_output", []> { + let summary = "send output operator"; + let description = [{ + Send the specified output. + + In DISC, outputs are not returned to the caller directly. Instead, users need to + explicitly send the individual output using this op. + }]; + let arguments = (ins + RAL_RalExecutionContextType:$ctx, + Index:$output_idx, + MemRefOf<[AnyType]>:$result + ); + let results = (outs); +} + +def RAL_DispatchOp: RAL_Op<"dispatch", []> { + let summary = "Dispatch operator"; + let description = [{ + A dispatch op invokes code external to disc. The `args` are passed to the + external code, and the external code is expected to produce a result of the + given type. The exact mechanism is backend-specific. + + `call_target_name` and `backend_config` can be arbitrary strings, but + `call_target_name` should be short as it may be used in labels. + `backend_config` can encode arbitrarily large amounts of information. + }]; + let arguments = (ins + RAL_RalExecutionContextType:$ctx, + Variadic:$args, + StrAttr:$call_target_name, + DefaultValuedAttr:$has_side_effect, + DefaultValuedAttr:$backend_config + ); + let results = (outs Variadic); +} + +#endif // DISC_RAL_OPS diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt b/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt index 6de6851..e6af855 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt +++ b/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt @@ -21,3 +21,7 @@ add_public_tablegen_target(MLIRMhloPassIncGen) set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td) mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name LMHLO) add_public_tablegen_target(MLIRLmhloPassIncGen) + +set(LLVM_TARGET_DEFINITIONS disc_ral_passes.td) +mlir_tablegen(disc_ral_passes.h.inc -gen-pass-decls -name RAL) +add_public_tablegen_target(MLIRDiscRalPassIncGen) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h index 5f18eeb..f3b6e2e 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/PassDetail.h @@ -27,4 +27,13 @@ namespace mhlo { } // end namespace mhlo } // end namespace mlir +namespace mlir { +namespace disc_ral { + +#define GEN_PASS_CLASSES +#include "mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.h.inc" + +} // end namespace disc_ral +} // end namespace mlir + #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSDETAIL_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.td new file mode 100644 index 0000000..10ed7ef --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.td @@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def RalInjectExecutionContextPass : Pass<"disc-ral-inject-execution-context", "ModuleOp"> { + let summary = "Inject DISC RAL execution context."; + let constructor = "createRalInjectExecutionContextPass()"; + let options = [ + Option<"entry_func_name_", "entry-func-name", "std::string", + /*default=*/"\"main\"", "Name of the entry function.">, + ]; +} diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 1ae5d43..cc2b6eb 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -119,6 +119,13 @@ std::unique_ptr> createLegalizeTensorLoadOpPass(); } // namespace lmhlo +namespace disc_ral { + +std::unique_ptr> createRalInjectExecutionContextPass( + const std::string& entry_func_name = "main"); + +} // namespace disc_ral + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h index 3b6041c..d76058c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -43,4 +43,15 @@ inline void registerAllLmhloPasses() { registerLMHLOPasses(); } } // namespace lmhlo } // namespace mlir +namespace mlir { +namespace disc_ral { + +#define GEN_PASS_REGISTRATION +#include "mlir-hlo/Dialect/mhlo/transforms/disc_ral_passes.h.inc" + +inline void registerAllDiscRalPasses() { registerRALPasses(); } + +} // namespace disc_ral +} // namespace mlir + #endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ diff --git a/lib/Dialect/mhlo/IR/CMakeLists.txt b/lib/Dialect/mhlo/IR/CMakeLists.txt index 6786de6..9ee9bb0 100644 --- a/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -91,6 +91,17 @@ target_link_libraries(LmhloGPUDialect HloOpsCommon ) +add_mlir_dialect_library(DiscRalDialect + disc_ral_ops.cc + + DEPENDS + MLIRdisc_ral_opsIncGen +) +target_link_libraries(DiscRalDialect + PUBLIC + MLIRIR +) + add_mlir_dialect_library(MhloRegisterDialects init.cc DEPENDS diff --git a/lib/Dialect/mhlo/IR/disc_ral_ops.cc b/lib/Dialect/mhlo/IR/disc_ral_ops.cc new file mode 100644 index 0000000..8ddb7c0 --- /dev/null +++ b/lib/Dialect/mhlo/IR/disc_ral_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the DISC RAL dialect. + +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" + +namespace mlir { +namespace disc_ral { + +template +static LogicalResult Verify(T op) { + return success(); +} + +//===----------------------------------------------------------------------===// +// ral Dialect Constructor +//===----------------------------------------------------------------------===// + +RalDialect::RalDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.cc.inc" + >(); + addTypes(); + context->loadDialect(); +} + +Type RalDialect::parseType(DialectAsmParser& parser) const { + StringRef data_type; + if (parser.parseKeyword(&data_type)) return Type(); + + if (data_type == "context") return RalExecutionContextType::get(getContext()); + parser.emitError(parser.getNameLoc()) + << "unknown disc_ral type: " << data_type; + return nullptr; +} + +void RalDialect::printType(Type type, DialectAsmPrinter& os) const { + if (type.isa()) { + os << "context"; + return; + } + os << ""; +} + +} // namespace disc_ral +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.cc.inc" diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index 3ce712b..48c2afd 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -154,9 +154,26 @@ add_mlir_library(LmhloPasses MLIRPass ) +add_mlir_library(DiscRalPasses + ral_inject_execution_context.cc + + DEPENDS + MLIRdisc_ral_opsIncGen + MLIRDiscRalPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + DiscRalDialect + MLIRIR + MLIRPass +) + add_library(AllMhloPasses INTERFACE) target_link_libraries(AllMhloPasses INTERFACE ChloPasses + DiscRalPasses MhloPasses MhloToLhloConversion MhloToStandard diff --git a/lib/Dialect/mhlo/transforms/ral_inject_execution_context.cc b/lib/Dialect/mhlo/transforms/ral_inject_execution_context.cc new file mode 100644 index 0000000..b5ef862 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/ral_inject_execution_context.cc @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for injecting execution context to the entry +// function. +// +// Below is an example. Before Conversion: +// ``` +// func @main(%arg0: memref, %arg1: memref) -> +// memref { +// %0 = memref.alloc(...) +// "lmhlo.add"(%arg0, %arg1, %0) : (memref, memref, +// memref) -> memref return %0 : memref +// } +// ``` +// After conversion: +// ``` +// func @main(%ctx: !disc_ral.context) { +// %c0 = constant 0 : index +// %c1 = constant 1 : index +// "disc_ral.recv_input"(%ctx, %c0) : (!disc_ral.context, index) -> +// memref "disc_ral.recv_input"(%ctx, %c1) : (!disc_ral.context, +// index) -> memref %0 = memref.alloc(...) "lmhlo.add"(%arg0, +// %arg1, %0) : (memref, memref, memref) -> +// memref "disc_ral.send_output"(%ctx, %c0, %0) : +// (!disc_ral.context, index, memref) -> () +// } +// ``` + +// 1. rewrite entry function (supposed that no other function directly calls the +// entry function) +// - function signature rewrite +// - return-like ops rewrite. +// 2. Currently we suppose that functions except the entry function are inlined +// to the entry function. Thus, we don't rewrite all call ops and other +// functions a.t.m. Re-visit this assumption if necessary. + +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace disc_ral { + +namespace { + +struct RalInjectExecutionContextPass + : public RalInjectExecutionContextPassBase { + explicit RalInjectExecutionContextPass(const std::string& entry_func_name) + : RalInjectExecutionContextPassBase:: + RalInjectExecutionContextPassBase() { + this->entry_func_name_ = entry_func_name; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + FuncOp main = m.lookupSymbol(entry_func_name_); + if (!main) { + m.emitError("entry func: " + entry_func_name_ + " not found"); + signalPassFailure(); + } + + Location loc = main.getLoc(); + FunctionType funcType = main.getType(); + OpBuilder b(&main.getBody()); + Block* entry_block = &main.getBody().front(); + Type ctx_type = RalExecutionContextType::get(b.getContext()); + + // 1. Prepend context to the entry block arguments + Value ctx = entry_block->insertArgument(0u, ctx_type); + + // 2. remap original arguments to recv_input ops + for (auto&& en : llvm::enumerate( + llvm::zip(funcType.getInputs(), + entry_block->getArguments().drop_front(1)))) { + Value idx = b.create(loc, en.index()); + Type argType = std::get<0>(en.value()); + Value oldArgument = std::get<1>(en.value()); + Value newInput = b.create(loc, argType, ctx, idx); + oldArgument.replaceAllUsesWith(newInput); + } + + // 3. remap all return-like ops to send_output ops + for (auto& block : main.getBody()) { + if (block.empty()) continue; + Operation& operation = block.back(); + if (!operation.hasTrait()) continue; + b.setInsertionPoint(&operation); + for (auto& en : llvm::enumerate(operation.getOperands())) { + Value idx = b.create(loc, en.index()); + b.create(loc, ctx, idx, en.value()); + } + operation.eraseOperands(0, operation.getNumOperands()); + } + + // 4. remove unused block arguments of entry block + for (int i = 0, e = funcType.getInputs().size(); i < e; ++i) { + // continue to remove the 1st (starting from zero) argument + entry_block->eraseArgument(1); + } + + // 5. set entry func to new type + main.setType(b.getFunctionType({ctx_type}, {})); + } +}; + +} // namespace + +std::unique_ptr> createRalInjectExecutionContextPass( + const std::string& entry_func_name) { + return std::make_unique(entry_func_name); +} + +} // namespace disc_ral +} // namespace mlir diff --git a/tests/disc_ral_inject_execution_context.mlir b/tests/disc_ral_inject_execution_context.mlir new file mode 100644 index 0000000..6fa0461 --- /dev/null +++ b/tests/disc_ral_inject_execution_context.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-hlo-opt -disc-ral-inject-execution-context=entry-func-name=test \ +// RUN: -canonicalize %s -o - | FileCheck %s + +// CHECK-LABEL: func @test +// CHECK-SAME: (%[[CTX:.*]]: !disc_ral.context) { +func @test(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref) -> (memref, memref) { + // %[[T0:.*]] = "disc_ral.recv_input"(%[[CTX]], %c0) : (!disc_ral.context, index) -> memref + // %[[T1:.*]] = "disc_ral.recv_input"(%[[CTX]], %c1) : (!disc_ral.context, index) -> memref + // %[[T2:.*]] = "disc_ral.recv_input"(%[[CTX]], %c2) : (!disc_ral.context, index) -> memref + // %[[T3:.*]] = "disc_ral.recv_input"(%[[CTX]], %c3) : (!disc_ral.context, index) -> memref + // "lmhlo.abs"(%[[T0]], %[[T1]]) : (memref, memref) -> () + // "lmhlo.add"(%[[T1]], %[[T2]], %[[T3]]) : (memref, memref, memref) -> () + // "disc_ral.send_output"(%[[CTX]], %c0, %[[T0]]) : (!disc_ral.context, index, memref) -> () + // "disc_ral.send_output"(%[[CTX]], %c1, %[[T3]]) : (!disc_ral.context, index, memref) -> () + "lmhlo.abs"(%arg0, %arg1) : (memref, memref) -> () + "lmhlo.add"(%arg1, %arg2, %arg3) : (memref, memref, memref) -> () + return %arg1, %arg3 : memref, memref +} diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index 5c9a1e1..9a9cee8 100644 --- a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/disc_ral_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -26,6 +27,7 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); + mlir::disc_ral::registerAllDiscRalPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); @@ -33,6 +35,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry,