diff --git a/BUILD b/BUILD index fe62a1b..549ee59 100644 --- a/BUILD +++ b/BUILD @@ -426,6 +426,7 @@ cc_library( ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index b179531..87f3025 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 22106d9..bebe3dc 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -30,6 +30,7 @@ limitations under the License. #define CHLO_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" @@ -754,4 +755,54 @@ def HLOClient_MinimumBroadcastShapesOp : } +def HLOClient_RankSpecializationClusterOp + : HLOClient_Op<"rank_specialization_cluster", [ + SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">, + RecursiveSideEffects]> { + + let summary = "Cluster of operations that will be rank-specialized together."; + + let description = [{ + Groups compatible element-wise operatons together so that they can be + rank-specialized together. The operation takes and yields a variadic number + of (unranked) tensor operands. Its body region holds one block with one + block argument per input tensor of the same type. All operations in this + block must only operate on these block arguments. Results are returned + through the `rank_specialization_cluster_yield` operation. + + Example: + + ``` + %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>): + %1 = chlo.broadcast_multiply %arg0_, %arg1_ + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %2 = chlo.broadcast_add %1, %arg2_ + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> () + }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + ``` + }]; + + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$body); +} + +def HLOClient_RankSpecializationClusterYieldOp + : HLOClient_Op<"rank_specialization_cluster_yield", [NoSideEffect, + ReturnLike, Terminator, HasParent<"RankSpecializationClusterOp">]> { + + let summary = "Yield operation for `rank_specialization_cluster`"; + let description = [{ + This operation yields the results from within the + `chlo.rank_specialization_cluster` operation's region. The operation takes + an arbitrary number of operands and produces no results. The operand number + and types must match the number and types of the parent + `rank_specialization_cluster` operation's results. + }]; + + let arguments = (ins Variadic:$results); +} + #endif // CHLO_OPS diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 3a1caa9..044e498 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -418,6 +418,31 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents( return success(); } +//===----------------------------------------------------------------------===// +// RankSpecializationClusterOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(RankSpecializationClusterOp op) { + if (op.body().getArgumentTypes() != op.getOperandTypes()) + return op.emitOpError() << "block argument types must match operand types"; + + // All operands of nested ops must be defined in the body or declared by the + // cluster. + Block* body = op.getBody(); + for (Operation& nested : body->without_terminator()) { + if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) { + Operation* def = operand.get().getDefiningOp(); + if (def != nullptr && def->getBlock() == body) return true; + return llvm::is_contained(body->getArguments(), operand.get()); + })) { + return op.emitOpError() + << "nested ops must not depend on implicit operands"; + } + } + + return success(); +} + } // namespace chlo } // namespace mlir diff --git a/tests/chlo_ops.mlir b/tests/chlo_ops.mlir index a4d5f79..ad72543 100644 --- a/tests/chlo_ops.mlir +++ b/tests/chlo_ops.mlir @@ -24,3 +24,45 @@ func @minimum_broadcast_shapes_one_operand(%arg: tensor) { %0 = chlo.minimum_broadcast_shapes %arg : tensor -> tensor return } + +// ----- + +func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, + %arg2 : tensor<*xf32>) -> tensor<*xf32> { + %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>): + %1 = chlo.broadcast_multiply %arg0_, %arg1_ + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %2 = chlo.broadcast_add %1, %arg2_ + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> () + }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + // expected-error @+1{{block argument types must match operand types}} + %0 = "chlo.rank_specialization_cluster"(%arg0) ({ + ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>): + "chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> () + }) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, + %arg2 : tensor<*xf32>) -> tensor<*xf32> { + // expected-error @+1{{nested ops must not depend on implicit operands}} + %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>): + %1 = chlo.broadcast_multiply %arg0_, %arg1_ + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %2 = chlo.broadcast_add %1, %arg2 + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> () + }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +}