[MLIR][HLO] Implement `RegionBranchOpInterface` for rank specialization cluster
PiperOrigin-RevId: 373163196
This commit is contained in:
parent
1432db02e7
commit
7f7a86ad0d
|
@ -757,6 +757,7 @@ def HLOClient_MinimumBroadcastShapesOp :
|
|||
|
||||
def HLOClient_RankSpecializationClusterOp
|
||||
: HLOClient_Op<"rank_specialization_cluster", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
|
||||
RecursiveSideEffects]> {
|
||||
|
||||
|
|
|
@ -422,7 +422,21 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
|
|||
// RankSpecializationClusterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RankSpecializationClusterOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor>& regions) {
|
||||
// RankSpecializationClusterOp has unconditional control flows into the region
|
||||
// and back to the parent, so return the correct RegionSuccessor purely based
|
||||
// on the index being None or 0.
|
||||
if (index.hasValue()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
regions.push_back(RegionSuccessor(&body()));
|
||||
}
|
||||
|
||||
static LogicalResult Verify(RankSpecializationClusterOp op) {
|
||||
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
|
||||
if (op.body().getArgumentTypes() != op.getOperandTypes())
|
||||
return op.emitOpError() << "block argument types must match operand types";
|
||||
|
||||
|
|
|
@ -42,6 +42,19 @@ func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>,
|
||||
%arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1{{source has 2 operands, but target successor needs 1}}
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
|
||||
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
|
||||
"chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_)
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> ()
|
||||
}) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
||||
// expected-error @+1{{block argument types must match operand types}}
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
||||
|
|
Loading…
Reference in New Issue