diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index a4ca489..6dacce1 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -539,7 +539,7 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
 
 // Common base class for AllReduce, AllGather, and AllToAll.
 class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
-    LHLO_Op<name, !listconcat(traits, [SameOperandsElementType])> {
+    LHLO_Op<name, traits> {
   dag arguments_base = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
     Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
@@ -562,13 +562,13 @@ def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
     (ins I64Attr:$all_gather_dimension));
 }
 
-def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
+def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>,
                        BASE_HLO_AllReduceOp {
   let arguments = arguments_base;
   let regions = (region SizedRegion<1>:$computation);
 }
 
-def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
+def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>,
                       BASE_HLO_AllToAllOp {
   let arguments = !con(
     arguments_base,
diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir
index 97eec5d..4fd3a90 100644
--- a/tests/lhlo_ops.mlir
+++ b/tests/lhlo_ops.mlir
@@ -32,6 +32,16 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
 // -----
 
+// CHECK-LABEL: func @mixed_types_allgather
+func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
+  "lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
+      constrain_layout = false, replica_groups = dense<0> : tensor<1x1xi64>,
+      use_global_device_ids = false} : (memref<1x1xf32>, memref<1x1xi32>, memref<1x1xf32>, memref<1x1xi32>) -> ()
+  return
+}
+
+// -----
+
 func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
   // expected-error@+1 {{replica id #1 seen more than once}}
   "lmhlo.all_gather"(%input0, %output)