[XLA:GPU] Allow all-gather operands to have different element types.

- XLA's all-gather combiner can create such all-gathers, so relax the same element type
  trait for all-gathers.

PiperOrigin-RevId: 372380446
This commit is contained in:
Rahul Joshi 2021-05-06 11:02:00 -07:00 committed by TensorFlow MLIR Team
parent 819adef41f
commit 8c854886cb
2 changed files with 13 additions and 3 deletions

View File

@@ -539,7 +539,7 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
 // Common base class for AllReduce, AllGather, and AllToAll.
 class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
-    LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize, SameOperandsElementType])> {
+    LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize])> {
   dag arguments_base = (ins
     Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
     Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
@@ -562,13 +562,13 @@ def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
               (ins I64Attr:$all_gather_dimension));
 }

-def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
+def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>,
                        BASE_HLO_AllReduceOp {
   let arguments = arguments_base;
   let regions = (region SizedRegion<1>:$computation);
 }

-def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
+def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>,
                       BASE_HLO_AllToAllOp {
   let arguments = !con(
     arguments_base,

View File

@@ -32,6 +32,16 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
 // -----
+// CHECK-LABEL: func @mixed_types_allgather
+func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1: memref<1x1xi32>) {
+  "lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
+    constrain_layout = false, replica_groups = dense<0> : tensor<1x1xi64>,
+    use_global_device_ids = false} : (memref<1x1xf32>, memref<1x1xi32>, memref<1x1xf32>, memref<1x1xi32>) -> ()
+  return
+}
+
+// -----
 func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
   // expected-error@+1 {{replica id #1 seen more than once}}
   "lmhlo.all_gather"(%input0, %output)