PR #48667: [mlir-hlo] Added RegionBranchOpInterfaces to lmhlo operations.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/48667 Added RegionBranchOpInterfaces to lmhlo operations that use regions. This is needed, since the bufferization features in MLIR have to reason about the control flow within these operations. Copybara import of the project: -- 572fd7d850a46630b812da84e9094280f89f259e by Julian Gross <julian.gross@dfki.de>: Added RegionBranchOpInterfaces to lmhlo operations. PiperOrigin-RevId: 372070825
This commit is contained in:
parent
ac68145565
commit
6bc854f5d9
7
BUILD
7
BUILD
|
@ -14,6 +14,7 @@ td_library(
|
||||||
name = "hlo_ops_td_files",
|
name = "hlo_ops_td_files",
|
||||||
srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]) + [
|
srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]) + [
|
||||||
# TODO(gcmn): These should be encapsulate in a td_library.
|
# TODO(gcmn): These should be encapsulate in a td_library.
|
||||||
|
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||||
|
@ -23,6 +24,8 @@ td_library(
|
||||||
],
|
],
|
||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
|
||||||
|
"@llvm-project//mlir:LoopLikeInterfaceTdFiles",
|
||||||
"@llvm-project//mlir:MemRefOpsTdFiles",
|
"@llvm-project//mlir:MemRefOpsTdFiles",
|
||||||
"@llvm-project//mlir:OpBaseTdFiles",
|
"@llvm-project//mlir:OpBaseTdFiles",
|
||||||
"@llvm-project//mlir:SideEffectTdFiles",
|
"@llvm-project//mlir:SideEffectTdFiles",
|
||||||
|
@ -461,8 +464,10 @@ cc_library(
|
||||||
":lhlo_ops_structs_inc_gen",
|
":lhlo_ops_structs_inc_gen",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
|
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||||
"@llvm-project//mlir:CopyOpInterface",
|
"@llvm-project//mlir:CopyOpInterface",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:LoopLikeInterface",
|
||||||
"@llvm-project//mlir:MemRefDialect",
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SideEffects",
|
"@llvm-project//mlir:SideEffects",
|
||||||
|
@ -496,9 +501,11 @@ cc_library(
|
||||||
":lhlo_gpu_ops_structs",
|
":lhlo_gpu_ops_structs",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
|
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||||
"@llvm-project//mlir:CopyOpInterface",
|
"@llvm-project//mlir:CopyOpInterface",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:InferTypeOpInterface",
|
"@llvm-project//mlir:InferTypeOpInterface",
|
||||||
|
"@llvm-project//mlir:LoopLikeInterface",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SideEffects",
|
"@llvm-project//mlir:SideEffects",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
|
|
@ -31,7 +31,9 @@ limitations under the License.
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||||
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,9 @@ limitations under the License.
|
||||||
|
|
||||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/CopyOpInterface.td"
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
|
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td"
|
||||||
|
@ -241,8 +243,9 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>,
|
||||||
|
|
||||||
// TODO(timshen): Add a custom syntax for this.
|
// TODO(timshen): Add a custom syntax for this.
|
||||||
def LHLO_CaseOp: LHLO_Op<"case", [
|
def LHLO_CaseOp: LHLO_Op<"case", [
|
||||||
SingleBlockImplicitTerminator<"TerminatorOp">
|
SingleBlockImplicitTerminator<"TerminatorOp">,
|
||||||
]>, BASE_HLO_CaseOp {
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>]>,
|
||||||
|
BASE_HLO_CaseOp {
|
||||||
|
|
||||||
let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
|
let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
|
||||||
|
|
||||||
|
@ -250,7 +253,10 @@ def LHLO_CaseOp: LHLO_Op<"case", [
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(timshen): Add a custom syntax for this.
|
// TODO(timshen): Add a custom syntax for this.
|
||||||
def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp {
|
def LHLO_WhileOp: LHLO_Op<"while", [
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||||
|
DeclareOpInterfaceMethods<LoopLikeOpInterface>]>,
|
||||||
|
BASE_HLO_WhileOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
|
Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
|
||||||
OptionalAttr<I64Attr>:$trip_count);
|
OptionalAttr<I64Attr>:$trip_count);
|
||||||
|
@ -669,7 +675,10 @@ def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>,
|
||||||
// Late operations
|
// Late operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
def FusionOp : LHLO_Op<"fusion", [
|
||||||
|
SingleBlockImplicitTerminator<"TerminatorOp">,
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>
|
||||||
|
]> {
|
||||||
let summary = "Fusion operator";
|
let summary = "Fusion operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
Models the fusion instruction generated by the XLA compiler's fusion pass.
|
Models the fusion instruction generated by the XLA compiler's fusion pass.
|
||||||
|
@ -725,7 +734,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
}
|
}
|
||||||
|
|
||||||
def TerminatorOp :
|
def TerminatorOp :
|
||||||
LHLO_Op<"terminator", [Terminator]> {
|
LHLO_Op<"terminator", [ReturnLike, Terminator]> {
|
||||||
let summary = "LHLO termination operation";
|
let summary = "LHLO termination operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
Terminator operation for the LHLO dialect.
|
Terminator operation for the LHLO dialect.
|
||||||
|
|
|
@ -159,6 +159,23 @@ static LogicalResult Verify(AllReduceOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// CaseOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void CaseOp::getSuccessorRegions(Optional<unsigned> index,
|
||||||
|
ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor>& regions) {
|
||||||
|
// If the predecessor is the CaseOp, branch to all other branches.
|
||||||
|
if (!index.hasValue()) {
|
||||||
|
for (auto& branch : branches())
|
||||||
|
regions.push_back(RegionSuccessor(&branch, branch.getArguments()));
|
||||||
|
}
|
||||||
|
// If the predecessor is one of the branches, branch back to the parent
|
||||||
|
// operation.
|
||||||
|
regions.push_back(RegionSuccessor());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// CollectivePermuteOp
|
// CollectivePermuteOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -316,6 +333,36 @@ static LogicalResult Verify(ReduceWindowOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// WhileOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
|
||||||
|
ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor>& regions) {
|
||||||
|
// If the predecessor is the WhileOp or the body region, branch into the
|
||||||
|
// cond region.
|
||||||
|
if (!index.hasValue() || index.getValue() == 1) {
|
||||||
|
regions.push_back(RegionSuccessor(&cond(), cond().getArguments()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// If the predecessor is the cond region, we can branch to the body region
|
||||||
|
// or back to the parent operation.
|
||||||
|
regions.push_back(RegionSuccessor(&body(), body().getArguments()));
|
||||||
|
regions.push_back(RegionSuccessor());
|
||||||
|
}
|
||||||
|
|
||||||
|
Region& WhileOp::getLoopBody() { return body(); }
|
||||||
|
|
||||||
|
bool WhileOp::isDefinedOutsideOfLoop(Value value) {
|
||||||
|
return !body().isAncestor(value.getParentRegion());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult WhileOp::moveOutOfLoop(ArrayRef<Operation*> ops) {
|
||||||
|
for (auto op : ops) op->moveBefore(*this);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@ -334,5 +381,18 @@ void FusionOp::build(OpBuilder& builder, OperationState& result,
|
||||||
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
|
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FusionOp::getSuccessorRegions(Optional<unsigned> index,
|
||||||
|
ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor>& regions) {
|
||||||
|
// If the predecessor is the fusion region, jump back to the parent op.
|
||||||
|
if (index.hasValue()) {
|
||||||
|
assert(index.getValue() == 0 && "expected fusion region");
|
||||||
|
regions.push_back(RegionSuccessor());
|
||||||
|
} else {
|
||||||
|
// If the predecessor is the FusionOp, branch into the region.
|
||||||
|
regions.push_back(RegionSuccessor(®ion(), region().getArguments()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace lmhlo
|
} // namespace lmhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue