From 44deae2aa1f2deabb5a12d78d786b9a31ce9be0b Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Tue, 26 Jan 2021 17:23:49 -0800
Subject: [PATCH] [MLIR:HLO] Extend AllReduce to support multiple inputs and
 results (to model tuples).

- Instead of SameTypeOperands, add custom verification to check that
  operands and results pairwise have the same type.

PiperOrigin-RevId: 353986341
---
 include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td |  7 ++--
 lib/Dialect/mhlo/IR/lhlo_ops.cc              | 41 ++++++++++++++++++--
 tests/lhlo_ops.mlir                          | 30 ++++++++++++++
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 30a8222..7b8dd94 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -544,17 +544,18 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
   );
 }
 
-def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
+def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameVariadicOperandSize]>,
                        BASE_HLO_AllReduceOp {
   let arguments = (ins
-    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
+    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
     I64ElementsAttr:$replica_groups,
     DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
     OptionalAttr<ChannelHandle>:$channel_id,
     DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
   );
   let regions = (region SizedRegion<1>:$computation);
+  let verifier = [{ return Verify(*this); }];
 }
 
 def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc
index f4ca3a1..5d06a6f 100644
--- a/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -48,7 +48,7 @@ limitations under the License.
 namespace mlir {
 namespace lmhlo {
 
-LmhloDialect::LmhloDialect(MLIRContext *context)
+LmhloDialect::LmhloDialect(MLIRContext* context)
     : Dialect(getDialectNamespace(), context, TypeID::get<LmhloDialect>()) {
   addOperations<
 #define GET_OP_LIST
@@ -56,6 +56,41 @@ LmhloDialect::LmhloDialect(MLIRContext *context)
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceOp op) {
+  // AllReduce has variadic operands and results that have the same size.
+  // Each member of the operands should have the same type as the
+  // corresponding member of the results.
+  for (auto it : llvm::enumerate(
+           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
+    Type operandType = std::get<0>(it.value());
+    Type resultType = std::get<1>(it.value());
+    if (operandType != resultType)
+      return op.emitOpError("requires operand #")
+             << it.index() << " (type: " << operandType << ") and result #"
+             << it.index() << " (type: " << resultType << ") to have same type";
+  }
+
+  // Since AllReduce has a single reduction computation attached to it (which
+  // is applied over all the operands and results), they all need to have the
+  // same element type. Since we already check that each operand and the
+  // corresponding result have the same type, it is sufficient to check just
+  // the memref element type of each operand.
+  Type elementType =
+      op.operands().front().getType().cast<MemRefType>().getElementType();
+  bool allMatch = llvm::all_of(
+      op.operands().drop_front().getTypes(), [elementType](Type type) {
+        return type.cast<MemRefType>().getElementType() == elementType;
+      });
+  if (!allMatch)
+    return op.emitOpError("requires all operands to have same element type");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstOp.
 //===----------------------------------------------------------------------===//
@@ -99,10 +134,10 @@ namespace lmhlo {
 
 // TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
 
-void FusionOp::build(OpBuilder &builder, OperationState &result,
+void FusionOp::build(OpBuilder& builder, OperationState& result,
                      ArrayRef<NamedAttribute> attributes) {
   result.addAttributes(attributes);
-  Region *bodyRegion = result.addRegion();
+  Region* bodyRegion = result.addRegion();
   FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
 }
 
diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir
index 2167cad..62dab01 100644
--- a/tests/lhlo_ops.mlir
+++ b/tests/lhlo_ops.mlir
@@ -2,6 +2,36 @@
 
 // -----
 
+func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf32>) {
+  // expected-error@+1 {{requires operand #1 (type: 'memref<3xf32>') and result #1 (type: 'memref<2xf32>') to have same type}}
+  "lmhlo.all_reduce"(%input0, %input1, %input0, %input0) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %add = mhlo.add %arg0, %arg1 : tensor<f32>
+      "mhlo.return"(%add) : (tensor<f32>) -> ()
+    })
+  {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
+   replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
+   use_global_device_ids = false} : (memref<2xf32>, memref<3xf32>, memref<2xf32>, memref<2xf32>) -> ()
+  return
+}
+
+// -----
+
+func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
+  // expected-error@+1 {{requires all operands to have same element type}}
+  "lmhlo.all_reduce"(%input0, %input1, %input0, %input1) ({
+    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+      %add = mhlo.add %arg0, %arg1 : tensor<f32>
+      "mhlo.return"(%add) : (tensor<f32>) -> ()
+    })
+  {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
+   replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
+   use_global_device_ids = false} : (memref<2xf32>, memref<3xf16>, memref<2xf32>, memref<3xf16>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
   "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
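--
Note: for reference, a minimal sketch of a well-formed multi-operand all_reduce
under the new variadic signature (the function and value names below are
illustrative, not part of the patch). Each operand/result pair shares a type
and all operands share the f32 element type, so the new verifier accepts it:

  func @valid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf32>,
                        %output0: memref<2xf32>, %output1: memref<3xf32>) {
    // Pairwise check passes: operand #0/result #0 are memref<2xf32>,
    // operand #1/result #1 are memref<3xf32>; all element types are f32.
    "lmhlo.all_reduce"(%input0, %input1, %output0, %output1) ({
      ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
        %add = mhlo.add %arg0, %arg1 : tensor<f32>
        "mhlo.return"(%add) : (tensor<f32>) -> ()
      })
    {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
     replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
     use_global_device_ids = false}
        : (memref<2xf32>, memref<3xf32>, memref<2xf32>, memref<3xf32>) -> ()
    return
  }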