diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 57fdfb6..a948c32 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp { }]; } +class BASE_HLO_PartitionIdOp { + string summary = "PartitionId operator"; + + string description = [{ + Returns the unique ID (int32 scalar) of the partition. + }]; +} + class BASE_HLO_AllReduceOp { string summary = "AllReduce operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index ee39bfc..9706473 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp { let arguments = (ins Arg, "", [MemWrite]>); } +def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp { + let arguments = (ins Arg, "", [MemWrite]>); +} + def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>, BASE_HLO_TriangularSolveOp { let arguments = (ins