From 6bc854f5d92d94621d7d011927483519b248ec84 Mon Sep 17 00:00:00 2001 From: dfki-jugr Date: Wed, 5 May 2021 00:26:46 -0700 Subject: [PATCH] PR #48667: [mlir-hlo] Added RegionBranchOpInterfaces to lmhlo operations. Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/48667 Added RegionBranchOpInterfaces to lmhlo operations that use regions. This is needed, since the bufferization features in MLIR have to reason about the control flow within these operations. Copybara import of the project: -- 572fd7d850a46630b812da84e9094280f89f259e by Julian Gross : Added RegionBranchOpInterfaces to lmhlo operations. PiperOrigin-RevId: 372070825 --- BUILD | 7 +++ include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 2 + include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 19 +++++-- lib/Dialect/mhlo/IR/lhlo_ops.cc | 60 ++++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) diff --git a/BUILD b/BUILD index 0bab387..fe62a1b 100644 --- a/BUILD +++ b/BUILD @@ -14,6 +14,7 @@ td_library( name = "hlo_ops_td_files", srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]) + [ # TODO(gcmn): These should be encapsulate in a td_library. + "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", @@ -23,6 +24,8 @@ td_library( ], includes = ["include"], deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", "@llvm-project//mlir:MemRefOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectTdFiles", @@ -461,8 +464,10 @@ cc_library( ":lhlo_ops_structs_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", @@ -496,9 +501,11 @@ cc_library( ":lhlo_gpu_ops_structs", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 7d32cff..3eb9161 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -31,7 +31,9 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index eb8c4d1..a4ca489 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -35,7 +35,9 @@ limitations under the License. include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" @@ -241,8 +243,9 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, // TODO(timshen): Add a custom syntax for this. def LHLO_CaseOp: LHLO_Op<"case", [ - SingleBlockImplicitTerminator<"TerminatorOp"> - ]>, BASE_HLO_CaseOp { + SingleBlockImplicitTerminator<"TerminatorOp">, + DeclareOpInterfaceMethods]>, + BASE_HLO_CaseOp { let arguments = (ins Arg:$index); @@ -250,7 +253,10 @@ def LHLO_CaseOp: LHLO_Op<"case", [ } // TODO(timshen): Add a custom syntax for this. -def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp { +def LHLO_WhileOp: LHLO_Op<"while", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + BASE_HLO_WhileOp { let arguments = (ins Arg, "", [MemWrite]>:$cond_val, OptionalAttr:$trip_count); @@ -669,7 +675,10 @@ def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, // Late operations //===----------------------------------------------------------------------===// -def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> { +def FusionOp : LHLO_Op<"fusion", [ + SingleBlockImplicitTerminator<"TerminatorOp">, + DeclareOpInterfaceMethods + ]> { let summary = "Fusion operator"; let description = [{ Models the fusion instruction generated by the XLA compiler's fusion pass. @@ -725,7 +734,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] } def TerminatorOp : - LHLO_Op<"terminator", [Terminator]> { + LHLO_Op<"terminator", [ReturnLike, Terminator]> { let summary = "LHLO termination operation"; let description = [{ Terminator operation for the LHLO dialect. diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 1d6f5f4..5456f7e 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -159,6 +159,23 @@ static LogicalResult Verify(AllReduceOp op) { return success(); } +//===----------------------------------------------------------------------===// +// CaseOp +//===----------------------------------------------------------------------===// + +void CaseOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the CaseOp, branch to all other branches. + if (!index.hasValue()) { + for (auto& branch : branches()) + regions.push_back(RegionSuccessor(&branch, branch.getArguments())); + } + // If the predecessor is one of the branches, branch back to the parent + // operation. + regions.push_back(RegionSuccessor()); +} + //===----------------------------------------------------------------------===// // CollectivePermuteOp //===----------------------------------------------------------------------===// @@ -316,6 +333,36 @@ static LogicalResult Verify(ReduceWindowOp op) { return success(); } +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +void WhileOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the WhileOp or the body region, branch into the + // cond region. + if (!index.hasValue() || index.getValue() == 1) { + regions.push_back(RegionSuccessor(&cond(), cond().getArguments())); + return; + } + // If the predecessor is the cond region, we can branch to the body region + // or back to the parent operation. + regions.push_back(RegionSuccessor(&body(), body().getArguments())); + regions.push_back(RegionSuccessor()); +} + +Region& WhileOp::getLoopBody() { return body(); } + +bool WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { + for (auto op : ops) op->moveBefore(*this); + return success(); +} + } // namespace lmhlo } // namespace mlir @@ -334,5 +381,18 @@ void FusionOp::build(OpBuilder& builder, OperationState& result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } +void FusionOp::getSuccessorRegions(Optional index, + ArrayRef operands, + SmallVectorImpl& regions) { + // If the predecessor is the fusion region, jump back to the parent op. + if (index.hasValue()) { + assert(index.getValue() == 0 && "expected fusion region"); + regions.push_back(RegionSuccessor()); + } else { + // If the predecessor is the FusionOp, branch into the region. + regions.push_back(RegionSuccessor(®ion(), region().getArguments())); + } +} + } // namespace lmhlo } // namespace mlir