[XLA:GPU] Allow all-gather operands to have different element types.
- XLA's all-gather combiner can create such all-gathers, so relax the same element type trait for all-gathers. PiperOrigin-RevId: 372380446
This commit is contained in:
parent
819adef41f
commit
8c854886cb
|
@ -539,7 +539,7 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
|
||||||
|
|
||||||
// Common base class for AllReduce, AllGather, and AllToAll.
|
// Common base class for AllReduce, AllGather, and AllToAll.
|
||||||
class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
|
class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
|
||||||
LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize, SameOperandsElementType])> {
|
LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize])> {
|
||||||
dag arguments_base = (ins
|
dag arguments_base = (ins
|
||||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
|
||||||
|
@ -562,13 +562,13 @@ def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
|
||||||
(ins I64Attr:$all_gather_dimension));
|
(ins I64Attr:$all_gather_dimension));
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
|
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>,
|
||||||
BASE_HLO_AllReduceOp {
|
BASE_HLO_AllReduceOp {
|
||||||
let arguments = arguments_base;
|
let arguments = arguments_base;
|
||||||
let regions = (region SizedRegion<1>:$computation);
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
|
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>,
|
||||||
BASE_HLO_AllToAllOp {
|
BASE_HLO_AllToAllOp {
|
||||||
let arguments = !con(
|
let arguments = !con(
|
||||||
arguments_base,
|
arguments_base,
|
||||||
|
|
|
@ -32,6 +32,16 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @mixed_types_allgather
|
||||||
|
func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
|
||||||
|
"lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
|
||||||
|
constrain_layout = false, replica_groups = dense<0> : tensor<1x1xi64>,
|
||||||
|
use_global_device_ids = false} : (memref<1x1xf32>, memref<1x1xi32>, memref<1x1xf32>, memref<1x1xi32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
|
func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
|
||||||
// expected-error@+1 {{replica id #1 seen more than once}}
|
// expected-error@+1 {{replica id #1 seen more than once}}
|
||||||
"lmhlo.all_gather"(%input0, %output)
|
"lmhlo.all_gather"(%input0, %output)
|
||||||
|
|
Loading…
Reference in New Issue