[XLA:GPU] Add support for PartitionId

PiperOrigin-RevId: 354599221
This commit is contained in:
Rahul Joshi 2021-01-29 13:30:59 -08:00 committed by TensorFlow MLIR Team
parent b1ce05cfc9
commit 1be1123c70
2 changed files with 12 additions and 0 deletions

View File

@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
}]; }];
} }
class BASE_HLO_PartitionIdOp {
string summary = "PartitionId operator";
string description = [{
Returns the unique ID (int32 scalar) of the partition.
}];
}
class BASE_HLO_AllReduceOp { class BASE_HLO_AllReduceOp {
string summary = "AllReduce operator"; string summary = "AllReduce operator";

View File

@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>); let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
} }
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>, def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp { BASE_HLO_TriangularSolveOp {
let arguments = (ins let arguments = (ins