[XLA:GPU] Allow all-gather operands to have different element types.
- XLA's all-gather combiner can create such all-gathers, so relax the same element type trait for all-gathers. PiperOrigin-RevId: 372380446
This commit is contained in:
parent
819adef41f
commit
8c854886cb
|
@ -539,7 +539,7 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
|
|||
|
||||
// Common base class for AllReduce, AllGather, and AllToAll.
|
||||
class LHLO_CollectiveCommunicationOp<string name, list<OpTrait> traits = []> :
|
||||
LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize, SameOperandsElementType])> {
|
||||
LHLO_Op<name, !listconcat(traits, [SameVariadicOperandSize])> {
|
||||
dag arguments_base = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$results,
|
||||
|
@ -562,13 +562,13 @@ def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">,
|
|||
(ins I64Attr:$all_gather_dimension));
|
||||
}
|
||||
|
||||
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">,
|
||||
def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>,
|
||||
BASE_HLO_AllReduceOp {
|
||||
let arguments = arguments_base;
|
||||
let regions = (region SizedRegion<1>:$computation);
|
||||
}
|
||||
|
||||
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">,
|
||||
def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>,
|
||||
BASE_HLO_AllToAllOp {
|
||||
let arguments = !con(
|
||||
arguments_base,
|
||||
|
|
|
@ -32,6 +32,16 @@ func @invalid_allreduce(%input0: memref<2xf32>, %input1: memref<3xf16>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @mixed_types_allgather
|
||||
func @mixed_types_allgather(%a0: memref<1x1xf32>, %a1:memref<1x1xi32>) {
|
||||
"lmhlo.all_gather"(%a0, %a1, %a0, %a1) {all_gather_dimension = 0 : i64,
|
||||
constrain_layout = false, replica_groups = dense<0> : tensor<1x1xi64>,
|
||||
use_global_device_ids = false} : (memref<1x1xf32>, memref<1x1xi32>, memref<1x1xf32>, memref<1x1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_allgather(%input0: memref<2xf32>, %output: memref<8xf32>) {
|
||||
// expected-error@+1 {{replica id #1 seen more than once}}
|
||||
"lmhlo.all_gather"(%input0, %output)
|
||||
|
|
Loading…
Reference in New Issue