From 8c854886cb2ae681b1454b8a5d2acdca5d538b49 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 6 May 2021 11:02:00 -0700 Subject: [PATCH] [XLA:GPU] Allow all-gather operands to have different element types. - XLA's all-gather combiner can create such all-gathers, so relax the same element type trait for all-gathers. PiperOrigin-RevId: 372380446 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 6 +++--- tests/lhlo_ops.mlir | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index a4ca489..6dacce1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -539,7 +539,7 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>, // Common base class for AllReduce, AllGather, and AllToAll. class LHLO_CollectiveCommunicationOp traits = []> : - LHLO_Op { + LHLO_Op { dag arguments_base = (ins Arg, "", [MemRead]>:$operands, Arg, "", [MemWrite]>:$results, @@ -562,13 +562,13 @@ def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">, (ins I64Attr:$all_gather_dimension)); } -def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">, +def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>, BASE_HLO_AllReduceOp { let arguments = arguments_base; let regions = (region SizedRegion<1>:$computation); } -def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">, +def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>, BASE_HLO_AllToAllOp { let arguments = !con( arguments_base, diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir index 97eec5d..4fd3a90 100644 --- a/tests/lhlo_ops.mlir +++ b/tests/lhlo_ops.mlir @@ -32,6 +32,16 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) { // ----- +// CHECK-LABEL: func @mixed_types_allgather +func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) { + "lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64, + constrain_layout = false, replica_groups = dense<0> : tensor<1x1xi64>, + use_global_device_ids = false} : (memref<1x1xf32>, memref<1x1xi32>, memref<1x1xf32>, memref<1x1xi32>) -> () + return +} + +// ----- + func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) { // expected-error@+1 {{replica id #1 seen more than once}} "lmhlo.all_gather"(%input0, %output)