diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index a948c32..b5411e3 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -645,6 +645,15 @@ class BASE_HLO_PartitionIdOp {
   }];
 }
 
+class BASE_HLO_AllGatherOp {
+  string summary = "AllGather operator";
+
+  string description = [{
+    Performs concatenation across replicas.
+
+    See https://www.tensorflow.org/xla/operation_semantics#allgather
+  }];
+}
 class BASE_HLO_AllReduceOp {
   string summary = "AllReduce operator";
 
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 9706473..d1bdd49 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -544,9 +544,10 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
   );
 }
 
-def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameVariadicOperandSize]>,
-                       BASE_HLO_AllReduceOp {
-  let arguments = (ins
+// Common base class for AllReduce, AllGather, and AllToAll.
+class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
+    LHLO_Op<name, traits> {
+  dag arguments_base = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
     Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
     I64ElementsAttr:$replica_groups,
@@ -554,14 +555,33 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameVariadicOperandSize]>,
     OptionalAttr<ChannelHandle>:$channel_id,
     DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
   );
-  let regions = (region SizedRegion<1>:$computation);
   let verifier = [{ return Verify(*this); }];
   let extraClassDeclaration = [{
-    // AllReduce is cross replica if channel_id is not set.
+    // The operation is cross replica if channel_id is not set.
     bool IsCrossReplica() { return !channel_id().hasValue(); }
   }];
 }
 
+def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
+                       BASE_HLO_AllGatherOp {
+  let arguments = !con(
+    arguments_base,
+    (ins I64Attr:$all_gather_dimension));
+}
+
+def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
+                       BASE_HLO_AllReduceOp {
+  let arguments = arguments_base;
+  let regions = (region SizedRegion<1>:$computation);
+}
+
+def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
+                      BASE_HLO_AllToAllOp {
+  let arguments = !con(
+    arguments_base,
+    (ins OptionalAttr<I64Attr>:$split_dimension));
+}
+
 def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
     BASE_HLO_CollectivePermuteOp {
 
diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc
index 5d06a6f..048623a 100644
--- a/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -56,11 +57,64 @@ LmhloDialect::LmhloDialect(MLIRContext* context)
       >();
 }
 
+// Verifies replica groups attached to collective communication operations.
+// If the attribute is not empty, it must be a rank 2 tensor, and each replica
+// should appear exactly once. If `is_uniform_sized` is true, then we also
+// check that each group is of the same size. If the operation has
+// `use_global_device_ids` set, then replica groups cannot be empty.
+template <typename OpT>
+LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
+  DenseIntElementsAttr attr = op.replica_groups();
+  auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
+  if (!replica_group_type || replica_group_type.getRank() != 2 ||
+      !replica_group_type.getElementType().isInteger(/*width=*/64))
+    return op.emitOpError(
+        "replica groups should be a rank 2 tensor of 64 bit integers");
+
+  if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0}))
+    return success();
+
+  int64_t max_replica_id_seen = 0;
+  llvm::SmallSet<int64_t, 8> replica_seen;
+  for (int64_t id : attr.getValues<int64_t>()) {
+    if (is_uniform_sized && id == -1) {
+      return op.emitOpError("Invalid replica id -1");
+    }
+    if (id != -1) {
+      if (!replica_seen.insert(id).second) {
+        return op.emitOpError("replica id #") << id << " seen more than once";
+      }
+      max_replica_id_seen = std::max(max_replica_id_seen, id);
+    }
+  }
+
+  for (int64_t id = 0; id <= max_replica_id_seen; id++) {
+    if (!replica_seen.contains(id)) {
+      return op.emitOpError("replica id #")
+             << id << " not seen in replica groups";
+    }
+  }
+  return success();
+}
+
+// TODO(jurahul): Add verification for output shape.
+static LogicalResult Verify(AllGatherOp op) {
+  return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
+}
+
+// TODO(jurahul): Add verification for output shape.
+static LogicalResult Verify(AllToAllOp op) {
+  return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
+}
+
 //===----------------------------------------------------------------------===//
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
 static LogicalResult Verify(AllReduceOp op) {
+  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
+    return failure();
+
   // AllReduce had variadic operands and results that have the same size.
   // Each memeber of the operand should have the same type as the corresponding
   // member of the result.
@@ -73,21 +127,6 @@ static LogicalResult Verify(AllReduceOp op) {
            << it.index() << " (type: " << operandType << ") and result #"
            << it.index() << " (type: " << resultType << ") to have same type";
   }
-
-  // Since AllReduce has a single reduction computation attached to it (which is
-  // applied over all the operands and results), they all need to have the same
-  // element type. Since we already check that each operand and corresponding
-  // result has the same type, its sufficient to check just the memref element
-  // type for each operands.
-  Type elementType =
-      op.operands().front().getType().cast<MemRefType>().getElementType();
-  bool allMatch = llvm::all_of(
-      op.operands().drop_front().getType(), [elementType](Type type) {
-        return type.cast<MemRefType>().getElementType() == elementType;
-      });
-  if (!allMatch)
-    return op.emitOpError("requires all operands to have same element type");
-
   return success();
 }
 
diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir
index 62dab01..76be69f 100644
--- a/tests/lhlo_ops.mlir
+++ b/tests/lhlo_ops.mlir
@@ -10,7 +10,7 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf32>) {
     "mhlo.return"(%add) : (tensor<f32>) -> ()
   }) {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
-      replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
+      replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 4]]> : tensor<2x4xi64>,
       use_global_device_ids = false} : (memref<2xf32>, memref<3xf32>, memref<2xf32>, memref<2xf32>) -> ()
   return
 }
@@ -18,7 +18,7 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
-  // expected-error@+1 {{requires all operands to have same element type}}
+  // expected-error@+1 {{requires the same element type for all operands}}
   "lmhlo.all_reduce"(%input0, %input1, %input0, %input1) ({
   ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
     %add = mhlo.add %arg0, %arg1 : tensor<f32>
@@ -32,6 +32,39 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
 
 // -----
 
+func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
+  // expected-error@+1 {{replica id #1 seen more than once}}
+  "lmhlo.all_gather"(%input0, %output)
+    {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
+     replica_groups = dense<[[0, 1, 1, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
+     use_global_device_ids = false, all_gather_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
+  return
+}
+
+// -----
+
+func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) {
+  // expected-error@+1 {{replica id #4 not seen in replica groups}}
+  "lmhlo.all_to_all"(%input0, %output)
+    {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
+     replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>,
+     use_global_device_ids = false, split_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
+  return
+}
+
+// -----
+
+func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) {
+  // expected-error@+1 {{replica groups should be a rank 2 tensor of 64 bit integers}}
+  "lmhlo.all_to_all"(%input0, %output)
+    {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
+     replica_groups = dense<0> : tensor<1xi64>,
+     use_global_device_ids = false, split_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
   "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
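
Note: for contrast with the invalid cases above, a valid `lmhlo.all_gather` under the new `VerifyReplicaGroups` rules lists every replica id from 0 through the maximum exactly once, in a rank 2 tensor of i64, with no -1 padding (rejected for AllGather and AllToAll, whose groups must be uniformly sized). The sketch below is illustrative only and not part of the patch; the function name @valid_allgather is hypothetical, and the output shape is not yet checked (per the TODOs above).

func @valid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
  // Replica ids 0..7 each appear exactly once, so verification succeeds.
  "lmhlo.all_gather"(%input0, %output)
    {channel_id = {handle = 1 : i64, type = 0 : i64}, constrain_layout = false,
     replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>,
     use_global_device_ids = false, all_gather_dimension = 0 : i64} : (memref<2xf32>, memref<8xf32>) -> ()
  return
}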