diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 2632876..5274a96 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -954,6 +954,26 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_AllReduceScatterOp : HLO_Op<"all_reduce_scatter",
+    [SameOperandsAndResultElementType]> {
+  let summary = "AllReduceScatter operator";
+  let description = [{
+    Performs all_reduce followed by a scatter.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
+  }];
+
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    I64Attr:$scatter_dimension,
+    I64ElementsAttr:$replica_groups,
+    OptionalAttr<ChannelHandle>:$channel_handle
+  );
+  let regions = (region SizedRegion<1>:$computation);
+  let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_AllToAllOp : HLO_Op<"all_to_all",
     [NoSideEffect, SameOperandsElementType, SameOperandsShape]> {

diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h
index 6aae3d1..09234f0 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h
@@ -30,6 +30,10 @@ namespace hlo {
 LogicalResult VerifyCollectivePermuteSourceTargetPairs(
     Operation* op, DenseIntElementsAttr attr);
 
+LogicalResult VerifyAllReduceScatter(Operation* op, TypeRange operand_types,
+                                     TypeRange result_types,
+                                     uint64_t scatter_dimension);
+
 // Custom formatting for convolution window attributes.
 void printWindowAttributes(OpAsmPrinter& p, Operation* op,
                            llvm::Optional<DenseIntElementsAttr> window_strides,
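A minimal usage sketch of the new op, assuming four replicas in a single
group (the SSA names and shapes here are illustrative, not taken from the
patch): each replica contributes a 4x16 operand, the computation region
(here, addition) combines them element-wise, and the combined value is then
split along scatter_dimension = 1 into four 4x4 shards, one per group
member, so the result type is tensor<4x4xf32> because 16 / 4 = 4:

  %shard = "mhlo.all_reduce_scatter"(%input) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    // element-wise reduction applied across the replica group
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>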
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index fce1708..ccdf982 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -1105,6 +1105,19 @@ def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperand
   let regions = (region SizedRegion<1>:$computation);
 }
 
+def LHLO_AllReduceScatterOp : LHLO_CollectiveCommunicationOp<"all_reduce_scatter", [SameOperandsElementType]> {
+  let summary = "AllReduceScatter operator";
+  let description = [{
+    Performs all_reduce followed by a scatter.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allreducescatter
+  }];
+  let arguments = !con(
+    arguments_base,
+    (ins I64Attr:$scatter_dimension));
+  let regions = (region SizedRegion<1>:$computation);
+}
+
 def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]> {
   let arguments = !con(
     arguments_base,

diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 75a2d41..3c76042 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -196,6 +196,18 @@ Value MaybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
 
 }  // namespace
 
+//===----------------------------------------------------------------------===//
+// AllReduceScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceScatterOp op) {
+  return mlir::hlo::VerifyAllReduceScatter(
+      op,
+      /*operand_types=*/{op.operand().getType()},
+      /*result_types=*/{op.getType()},
+      /*scatter_dimension=*/op.scatter_dimension());
+}
+
 //===----------------------------------------------------------------------===//
 // ConstOp
 //===----------------------------------------------------------------------===//
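Note that the tensor form above and the buffer form added to lhlo_ops.cc
below share one verifier, mlir::hlo::VerifyAllReduceScatter. One consequence
of its divisibility check, shown as an illustrative sketch (not one of the
tests in this patch): with two replica groups of two members each, the
scatter dimension shrinks by the group size, so a 4x16 operand pairs with a
4x8 result since 16 / 2 = 8:

  %0 = "mhlo.all_reduce_scatter"(%data) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>,
      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x8xf32>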
diff --git a/lib/Dialect/mhlo/IR/hlo_ops_common.cc b/lib/Dialect/mhlo/IR/hlo_ops_common.cc
index 15b5e0a..b1fa79e 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops_common.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops_common.cc
@@ -53,6 +53,51 @@ LogicalResult VerifyCollectivePermuteSourceTargetPairs(
   return success();
 }
 
+LogicalResult VerifyAllReduceScatter(Operation *op, TypeRange operand_types,
+                                     TypeRange result_types,
+                                     uint64_t scatter_dimension) {
+  // If operand and result are both ranked, then the size of the scatter
+  // dimension in the operand should be a multiple of the size of the scatter
+  // dimension in the result.
+  for (auto it : llvm::zip(operand_types, result_types)) {
+    auto operand_type = std::get<0>(it).cast<ShapedType>();
+    auto result_type = std::get<1>(it).cast<ShapedType>();
+    if (!operand_type.hasRank() || !result_type.hasRank()) continue;
+    if (operand_type.getRank() != result_type.getRank())
+      return op->emitOpError() << "operand and result should have same rank";
+    if (scatter_dimension >= static_cast<uint64_t>(operand_type.getRank()))
+      return op->emitOpError()
+             << "scatter dim should be less than operand/result rank";
+    if (operand_type.isDynamicDim(scatter_dimension) ||
+        result_type.isDynamicDim(scatter_dimension))
+      continue;
+    if (operand_type.getDimSize(scatter_dimension) == 0)
+      return op->emitOpError() << "operand scatter dimension cannot be zero";
+    if (result_type.getDimSize(scatter_dimension) == 0)
+      return op->emitOpError() << "result scatter dimension cannot be zero";
+    if ((operand_type.getDimSize(scatter_dimension) %
+         result_type.getDimSize(scatter_dimension)) != 0)
+      return op->emitOpError()
+             << "operand scatter dimension has size "
+             << operand_type.getDimSize(scatter_dimension)
+             << ", expected to be a multiple of result scatter dimension size "
+             << result_type.getDimSize(scatter_dimension);
+
+    // Non-scatter dimensions should be equal.
+    for (uint64_t index : llvm::seq<uint64_t>(0, operand_type.getRank())) {
+      if (index == scatter_dimension || operand_type.isDynamicDim(index) ||
+          result_type.isDynamicDim(index))
+        continue;
+      if (operand_type.getDimSize(index) != result_type.getDimSize(index))
+        return op->emitOpError()
+               << "non scatter dimensions should be same for operand ("
+               << operand_type.getDimSize(index) << ") and result ("
+               << result_type.getDimSize(index) << ")";
+    }
+  }
+  return success();
+}
+
 namespace {
 // Custom formatting for convolution window attributes.
 void printWindowAttribute(OpAsmPrinter &p, DenseElementsAttr attribute) {

diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 72be3a0..9e00c91 100644
--- a/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -159,6 +159,21 @@ static LogicalResult Verify(AllReduceOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllReduceScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllReduceScatterOp op) {
+  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/true)))
+    return failure();
+  if (failed(mlir::hlo::VerifyAllReduceScatter(
+          op, /*operand_types=*/op.operands().getTypes(),
+          /*result_types=*/op.results().getTypes(),
+          /*scatter_dimension=*/op.scatter_dimension())))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CaseOp
 //===----------------------------------------------------------------------===//

diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir
index 9bd1308..2eb34ec 100644
--- a/tests/lhlo_ops.mlir
+++ b/tests/lhlo_ops.mlir
@@ -32,6 +32,19 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
 
 // -----
 
+// CHECK-LABEL: func @reduce_scatter
+func @reduce_scatter(%data: memref<4x16xf32>, %result: memref<4x4xf32>) {
+  "lmhlo.all_reduce_scatter"(%data, %result) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (memref<4x16xf32>, memref<4x4xf32>) -> ()
+  return
+}
+// -----
+
 // CHECK-LABEL: func @mixed_types_allgather
 func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
   "lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
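One detail of VerifyAllReduceScatter worth calling out: dimensions that are
dynamic in either the operand or the result are skipped by both the
divisibility check and the non-scatter-dimension equality check, so a
partially dynamic case like the following sketch (behavior read off the
verifier above; not among the tests in this patch) would pass verification:

  %0 = "mhlo.all_reduce_scatter"(%data) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
      scatter_dimension = 1 : i64} : (tensor<4x?xf32>) -> tensor<4x?xf32>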
diff --git a/tests/ops.mlir b/tests/ops.mlir
index f5d5002..e437a8d 100644
--- a/tests/ops.mlir
+++ b/tests/ops.mlir
@@ -13,6 +13,104 @@ func private @invalid_type() -> !mhlo.foobar
 
 // -----
 
+// CHECK-LABEL: func @reduce_scatter
+func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x5xf32> {
+  // expected-error@+1 {{operand scatter dimension has size 16, expected to be a multiple of result scatter dimension size 5}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x5xf32>
+  return %0 : tensor<4x5xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> {
+  // expected-error@+1 {{operand scatter dimension cannot be zero}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x0xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x0xf32> {
+  // expected-error@+1 {{result scatter dimension cannot be zero}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x0xf32>
+  return %0 : tensor<4x0xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{operand and result should have same rank}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
+  // expected-error@+1 {{scatter dim should be less than operand/result rank}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 4 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
+func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<3x4xf32> {
+  // expected-error@+1 {{non scatter dimensions should be same for operand (4) and result (3)}}
+  %0 = "mhlo.all_reduce_scatter"(%data) ( {
+    // reduction computation
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+      "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @alltoall
 func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> {
   %0 = "mhlo.all_to_all"(%data) {
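Finally, the mhlo variant also declares an optional channel_handle attribute
in hlo_ops.td above. A sketch of how that might be spelled, assuming the
same {handle, type} dictionary syntax used by mhlo.all_reduce (the handle
and type values here are placeholders):

  %0 = "mhlo.all_reduce_scatter"(%data) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %sum = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {channel_handle = {handle = 1 : i64, type = 0 : i64},
      replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
      scatter_dimension = 1 : i64} : (tensor<4x16xf32>) -> tensor<4x4xf32>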