[MLIR][HLO] Implement `RegionBranchOpInterface` for rank specialization cluster
PiperOrigin-RevId: 373163196
This commit is contained in:
parent
1432db02e7
commit
7f7a86ad0d
|
@ -757,6 +757,7 @@ def HLOClient_MinimumBroadcastShapesOp :
|
||||||
|
|
||||||
def HLOClient_RankSpecializationClusterOp
|
def HLOClient_RankSpecializationClusterOp
|
||||||
: HLOClient_Op<"rank_specialization_cluster", [
|
: HLOClient_Op<"rank_specialization_cluster", [
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||||
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
|
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
|
||||||
RecursiveSideEffects]> {
|
RecursiveSideEffects]> {
|
||||||
|
|
||||||
|
|
|
@ -422,7 +422,21 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
|
||||||
// RankSpecializationClusterOp
|
// RankSpecializationClusterOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void RankSpecializationClusterOp::getSuccessorRegions(
|
||||||
|
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor>& regions) {
|
||||||
|
// RankSpecializationClusterOp has unconditional control flows into the region
|
||||||
|
// and back to the parent, so return the correct RegionSuccessor purely based
|
||||||
|
// on the index being None or 0.
|
||||||
|
if (index.hasValue()) {
|
||||||
|
regions.push_back(RegionSuccessor(getResults()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
regions.push_back(RegionSuccessor(&body()));
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult Verify(RankSpecializationClusterOp op) {
|
static LogicalResult Verify(RankSpecializationClusterOp op) {
|
||||||
|
if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
|
||||||
if (op.body().getArgumentTypes() != op.getOperandTypes())
|
if (op.body().getArgumentTypes() != op.getOperandTypes())
|
||||||
return op.emitOpError() << "block argument types must match operand types";
|
return op.emitOpError() << "block argument types must match operand types";
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,19 @@ func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @rank_specialization_cluster(%arg0 : tensor<*xf32>,
|
||||||
|
%arg1 : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// expected-error @+1{{source has 2 operands, but target successor needs 1}}
|
||||||
|
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
|
||||||
|
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
|
||||||
|
"chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_)
|
||||||
|
: (tensor<*xf32>, tensor<*xf32>) -> ()
|
||||||
|
}) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// expected-error @+1{{block argument types must match operand types}}
|
// expected-error @+1{{block argument types must match operand types}}
|
||||||
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
||||||
|
|
Loading…
Reference in New Issue