diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index cfe4da7..ba4749e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> { // MHLO parallelism related op definitions. //===----------------------------------------------------------------------===// -def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, +def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect, + DeclareOpInterfaceMethods]>, BASE_HLO_ReplicaIdOp { let results = (outs TensorOf<[UI32]>); } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 64b42c9..aecf784 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results, context); } +//===----------------------------------------------------------------------===// +// ReplicaId Op +//===----------------------------------------------------------------------===// + +LogicalResult ReplicaIdOp::inferReturnTypes( + MLIRContext* context, Optional, ValueRange operands, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(RankedTensorType::get( + /*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context))); + return success(); +} + //===----------------------------------------------------------------------===// // Case Op //===----------------------------------------------------------------------===//