diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 05fc6bf..0f8d721 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -909,6 +909,26 @@ def HLO_WhileOp: HLO_Op<"while", [
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
+
+  string summary = "AllGather operator";
+
+  string description = [{
+    Performs concatenation across replicas.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allgather
+  }];
+
+  let arguments = (ins
+    HLO_Tensor:$operand,
+    I64Attr:$all_gather_dim,
+    I64ElementsAttr:$replica_groups,
+    OptionalAttr<ChannelHandle>:$channel_handle
+  );
+  let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_AllReduceOp : HLO_Op<"all_reduce",
     [SameOperandsAndResultType]> {
   let summary = "AllReduce operator";
@@ -921,7 +941,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
   let arguments = (ins
     HLO_Tensor:$operand,
     I64ElementsAttr:$replica_groups,
-    OptionalAttr<ChannelHandle>:$channel_id
+    OptionalAttr<ChannelHandle>:$channel_handle
   );
   let regions = (region SizedRegion<1>:$computation);
   let results = (outs HLO_Tensor);
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index f91bf4b..d363844 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -725,6 +725,34 @@ static LogicalResult Verify(AllToAllOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(AllGatherOp op) {
+  // If operand and result are both ranked, then the size of the gather
+  // dimension in the result should be a multiple of the size of the gather
+  // dimension in the operand.
+  auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
+  auto resultType = op.getType().dyn_cast<RankedTensorType>();
+  uint64_t allGatherDimIndex = op.all_gather_dim();
+  if (!operandType || !resultType ||
+      operandType.isDynamicDim(allGatherDimIndex) ||
+      resultType.isDynamicDim(allGatherDimIndex))
+    return success();
+  if (operandType.getDimSize(allGatherDimIndex) == 0)
+    return op.emitOpError() << "operand gather dimension cannot be zero.";
+  if ((resultType.getDimSize(allGatherDimIndex) %
+       operandType.getDimSize(allGatherDimIndex)) != 0)
+    return op.emitOpError()
+           << "result gather dimension has size "
+           << resultType.getDimSize(allGatherDimIndex)
+           << ", expected to be a multiple of operand gather dimension size "
+           << operandType.getDimSize(allGatherDimIndex);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/tests/ops.mlir b/tests/ops.mlir
index 7c33d2f..48192fe 100644
--- a/tests/ops.mlir
+++ b/tests/ops.mlir
@@ -52,6 +52,42 @@ func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf3
 
 // -----
 
+func @allgather_incompatible_types(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> {
+  // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}}
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x32xf32>) -> tensor<128x100xf32>
+  return %0 : tensor<128x100xf32>
+}
+
+// -----
+
+func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0x32xf32>) -> tensor<128x100xf32> {
+  // expected-error@+1 {{operand gather dimension cannot be zero}}
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x0x32xf32>) -> tensor<128x100xf32>
+  return %0 : tensor<128x100xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @allgather_dynamic_gather_dim
+func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x?xf32> {
+  %0 = "mhlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    channel_handle = {handle = 1 : i64, type = 0 : i64},
+    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
+  } : (tensor<128x32xf32>) -> tensor<128x?xf32>
+  return %0 : tensor<128x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @broadcast
 func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
   %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
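
For reference, a use of the op that passes the new verifier (a hypothetical example, not part of the patch; the function name @allgather_valid is invented here). With two replica groups of four replicas each and all_gather_dim = 1, the operand's gather dimension of size 32 grows by the group size to 32 * 4 = 128; since 128 is a multiple of 32 and neither dimension is dynamic or zero, verification succeeds:

// Hypothetical usage example (not in the patch above): the result's gather
// dimension (128) is a multiple of the operand's gather dimension (32).
func @allgather_valid(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> {
  %0 = "mhlo.all_gather"(%arg0) {
    all_gather_dim = 1 : i64,
    channel_handle = {handle = 1 : i64, type = 0 : i64},
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  } : (tensor<128x32xf32>) -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}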