diff --git a/BUILD b/BUILD index 0bab387..fe62a1b 100644 --- a/BUILD +++ b/BUILD @@ -14,6 +14,7 @@ td_library( name = "hlo_ops_td_files", srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]) + [ # TODO(gcmn): These should be encapsulate in a td_library. + "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", @@ -23,6 +24,8 @@ td_library( ], includes = ["include"], deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", "@llvm-project//mlir:MemRefOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectTdFiles", @@ -461,8 +464,10 @@ cc_library( ":lhlo_ops_structs_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", @@ -496,9 +501,11 @@ cc_library( ":lhlo_gpu_ops_structs", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 7d32cff..3eb9161 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -31,7 +31,9 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index eb8c4d1..a4ca489 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -35,7 +35,9 @@ limitations under the License. include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" @@ -241,8 +243,9 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, // TODO(timshen): Add a custom syntax for this. def LHLO_CaseOp: LHLO_Op<"case", [ - SingleBlockImplicitTerminator<"TerminatorOp"> - ]>, BASE_HLO_CaseOp { + SingleBlockImplicitTerminator<"TerminatorOp">, + DeclareOpInterfaceMethods]>, + BASE_HLO_CaseOp { let arguments = (ins Arg:$index); @@ -250,7 +253,10 @@ def LHLO_CaseOp: LHLO_Op<"case", [ } // TODO(timshen): Add a custom syntax for this. -def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp { +def LHLO_WhileOp: LHLO_Op<"while", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + BASE_HLO_WhileOp { let arguments = (ins Arg, "", [MemWrite]>:$cond_val, OptionalAttr:$trip_count); @@ -669,7 +675,10 @@ def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, // Late operations //===----------------------------------------------------------------------===// -def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> { +def FusionOp : LHLO_Op<"fusion", [ + SingleBlockImplicitTerminator<"TerminatorOp">, + DeclareOpInterfaceMethods + ]> { let summary = "Fusion operator"; let description = [{ Models the fusion instruction generated by the XLA compiler's fusion pass. @@ -725,7 +734,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] } def TerminatorOp : - LHLO_Op<"terminator", [Terminator]> { + LHLO_Op<"terminator", [ReturnLike, Terminator]> { let summary = "LHLO termination operation"; let description = [{ Terminator operation for the LHLO dialect. diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 1d6f5f4..5456f7e 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -159,6 +159,23 @@ static LogicalResult Verify(AllReduceOp op) { return success(); } +//===----------------------------------------------------------------------===// +// CaseOp +//===----------------------------------------------------------------------===// + +void CaseOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the CaseOp, branch to all other branches. + if (!index.hasValue()) { + for (auto& branch : branches()) + regions.push_back(RegionSuccessor(&branch, branch.getArguments())); + } + // If the predecessor is one of the branches, branch back to the parent + // operation. + regions.push_back(RegionSuccessor()); +} + //===----------------------------------------------------------------------===// // CollectivePermuteOp //===----------------------------------------------------------------------===// @@ -316,6 +333,36 @@ static LogicalResult Verify(ReduceWindowOp op) { return success(); } +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +void WhileOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the WhileOp or the body region, branch into the + // cond region. + if (!index.hasValue() || index.getValue() == 1) { + regions.push_back(RegionSuccessor(&cond(), cond().getArguments())); + return; + } + // If the predecessor is the cond region, we can branch to the body region + // or back to the parent operation. + regions.push_back(RegionSuccessor(&body(), body().getArguments())); + regions.push_back(RegionSuccessor()); +} + +Region& WhileOp::getLoopBody() { return body(); } + +bool WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { + for (auto op : ops) op->moveBefore(*this); + return success(); +} + } // namespace lmhlo } // namespace mlir @@ -334,5 +381,18 @@ void FusionOp::build(OpBuilder& builder, OperationState& result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } +void FusionOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the fusion region, jump back to the parent op. + if (index.hasValue()) { + assert(index.getValue() == 0 && "expected fusion region"); + regions.push_back(RegionSuccessor()); + } else { + // If the predecessor is the FusionOp, branch into the region. + regions.push_back(RegionSuccessor(®ion(), region().getArguments())); + } +} + } // namespace lmhlo } // namespace mlir