Legalize XlaReplicaId to HLO replica-id op
Also, define shape inference function for HLO replica-id op. PiperOrigin-RevId: 345714342
This commit is contained in:
parent
e48881af81
commit
9bd1995f90
|
@ -490,7 +490,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
|
|||
// MHLO parallelism related op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
|
||||
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||
BASE_HLO_ReplicaIdOp {
|
||||
let results = (outs TensorOf<[UI32]>);
|
||||
}
|
||||
|
|
|
@ -1945,6 +1945,18 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
|||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReplicaId Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ReplicaIdOp::inferReturnTypes(
|
||||
MLIRContext* context, Optional<Location>, ValueRange operands,
|
||||
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(RankedTensorType::get(
|
||||
/*shape=*/{}, IntegerType::get(32, IntegerType::Unsigned, context)));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Case Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue