Legalize XlaReplicaId to HLO replica-id op

Also, define shape inference function for HLO replica-id op.

PiperOrigin-RevId: 345714342
This commit is contained in:
Smit Hinsu 2020-12-04 11:04:02 -08:00 committed by TensorFlow MLIR Team
parent e48881af81
commit 9bd1995f90
2 changed files with 14 additions and 1 deletions

View File

@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
// MHLO parallelism related op definitions. // MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
BASE_HLO_ReplicaIdOp { BASE_HLO_ReplicaIdOp {
let results = (outs TensorOf<[UI32]>); let results = (outs TensorOf<[UI32]>);
} }

View File

@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
context); context);
} }
//===----------------------------------------------------------------------===//
// ReplicaId Op
//===----------------------------------------------------------------------===//
LogicalResult ReplicaIdOp::inferReturnTypes(
MLIRContext* context, Optional<Location>, ValueRange operands,
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(RankedTensorType::get(
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Case Op // Case Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//