[MLIR][HLO] Implement `RegionBranchOpInterface` for rank specialization cluster

PiperOrigin-RevId: 373163196
This commit is contained in:
A. Unique TensorFlower 2021-05-11 09:01:41 -07:00 committed by TensorFlow MLIR Team
parent 1432db02e7
commit 7f7a86ad0d
3 changed files with 28 additions and 0 deletions

View File

@ -757,6 +757,7 @@ def HLOClient_MinimumBroadcastShapesOp :
def HLOClient_RankSpecializationClusterOp
: HLOClient_Op<"rank_specialization_cluster", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
RecursiveSideEffects]> {

View File

@ -422,7 +422,21 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
// RankSpecializationClusterOp
//===----------------------------------------------------------------------===//
void RankSpecializationClusterOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor>& regions) {
// RankSpecializationClusterOp has unconditional control flows into the region
// and back to the parent, so return the correct RegionSuccessor purely based
// on the index being None or 0.
if (index.hasValue()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
regions.push_back(RegionSuccessor(&body()));
}
static LogicalResult Verify(RankSpecializationClusterOp op) {
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
if (op.body().getArgumentTypes() != op.getOperandTypes())
return op.emitOpError() << "block argument types must match operand types";

View File

@ -42,6 +42,19 @@ func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
// -----
func @rank_specialization_cluster(%arg0 : tensor<*xf32>,
%arg1 : tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1{{source has 2 operands, but target successor needs 1}}
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
"chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_)
: (tensor<*xf32>, tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1{{block argument types must match operand types}}
%0 = "chlo.rank_specialization_cluster"(%arg0) ({