Legalize XlaReplicaId to HLO replica-id op

Also, define shape inference function for HLO replica-id op.

PiperOrigin-RevId: 345714342
This commit is contained in:
Smit Hinsu 2020-12-04 11:04:02 -08:00 committed by TensorFlow MLIR Team
parent e48881af81
commit 9bd1995f90
2 changed files with 14 additions and 1 deletions

View File

@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
// MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
BASE_HLO_ReplicaIdOp {
let results = (outs TensorOf<[UI32]>);
}

View File

@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
context);
}
//===----------------------------------------------------------------------===//
// ReplicaId Op
//===----------------------------------------------------------------------===//
LogicalResult ReplicaIdOp::inferReturnTypes(
MLIRContext* context, Optional<Location>, ValueRange operands,
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(RankedTensorType::get(
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
return success();
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//