From 7f7a86ad0dc9596636899a6d97bde5ac35909a39 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 11 May 2021 09:01:41 -0700 Subject: [PATCH] [MLIR][HLO] Implement `RegionBranchOpInterface` for rank specialization cluster PiperOrigin-RevId: 373163196 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 1 + lib/Dialect/mhlo/IR/chlo_ops.cc | 14 ++++++++++++++ tests/chlo_ops.mlir | 13 +++++++++++++ 3 files changed, 28 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index bebe3dc..ce384d2 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -757,6 +757,7 @@ def HLOClient_MinimumBroadcastShapesOp : def HLOClient_RankSpecializationClusterOp : HLOClient_Op<"rank_specialization_cluster", [ + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">, RecursiveSideEffects]> { diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 044e498..1fa000e 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -422,7 +422,21 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents( // RankSpecializationClusterOp //===----------------------------------------------------------------------===// +void RankSpecializationClusterOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl& regions) { + // RankSpecializationClusterOp has unconditional control flows into the region + // and back to the parent, so return the correct RegionSuccessor purely based + // on the index being None or 0. + if (index.hasValue()) { + regions.push_back(RegionSuccessor(getResults())); + return; + } + regions.push_back(RegionSuccessor(&body())); +} + static LogicalResult Verify(RankSpecializationClusterOp op) { + if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure(); if (op.body().getArgumentTypes() != op.getOperandTypes()) return op.emitOpError() << "block argument types must match operand types"; diff --git a/tests/chlo_ops.mlir b/tests/chlo_ops.mlir index ad72543..922192e 100644 --- a/tests/chlo_ops.mlir +++ b/tests/chlo_ops.mlir @@ -42,6 +42,19 @@ func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, // ----- +func @rank_specialization_cluster(%arg0 : tensor<*xf32>, + %arg1 : tensor<*xf32>) -> tensor<*xf32> { + // expected-error @+1{{source has 2 operands, but target successor needs 1}} + %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({ + ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>): + "chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_) + : (tensor<*xf32>, tensor<*xf32>) -> () + }) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> { // expected-error @+1{{block argument types must match operand types}} %0 = "chlo.rank_specialization_cluster"(%arg0) ({