[XLA:GPU] Add support for PartitionId
PiperOrigin-RevId: 354599221
This commit is contained in:
parent
b1ce05cfc9
commit
1be1123c70
|
@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_PartitionIdOp {
|
||||||
|
string summary = "PartitionId operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns the unique ID (int32 scalar) of the partition.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class BASE_HLO_AllReduceOp {
|
class BASE_HLO_AllReduceOp {
|
||||||
string summary = "AllReduce operator";
|
string summary = "AllReduce operator";
|
||||||
|
|
|
@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
|
||||||
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
|
||||||
|
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
||||||
|
}
|
||||||
|
|
||||||
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
|
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
|
||||||
BASE_HLO_TriangularSolveOp {
|
BASE_HLO_TriangularSolveOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
|
Loading…
Reference in New Issue