[MLIR][CHLO] Implement type inference for `is_inf`-like operations in CHLO

PiperOrigin-RevId: 354265834
This commit is contained in:
A. Unique TensorFlower 2021-01-28 01:35:42 -08:00 committed by TensorFlow MLIR Team
parent fe2e5a175f
commit c3ddcd6c7f
2 changed files with 48 additions and 5 deletions

View File

@ -524,7 +524,8 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
}]; }];
} }
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor, def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf",
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> { HLO_PredTensor> {
let summary = "IsInf predicate"; let summary = "IsInf predicate";
@ -533,8 +534,9 @@ def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
}]; }];
} }
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [], def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf",
HLO_FpTensor, HLO_PredTensor> { [DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsNegInf predicate"; let summary = "IsNegInf predicate";
let description = [{ let description = [{
@ -542,8 +544,9 @@ def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
}]; }];
} }
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [], def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf",
HLO_FpTensor, HLO_PredTensor> { [DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsPosInf predicate"; let summary = "IsPosInf predicate";
let description = [{ let description = [{

View File

@ -227,12 +227,52 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
attributes, element_type, attributes, element_type,
inferedReturnShapes); inferedReturnShapes);
} }
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes); reifiedReturnShapes);
} }
//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//
static Type getIsInfLikeReturnType(Value operand) {
Builder b(operand.getContext());
return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
b.getI1Type());
}
LogicalResult IsInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===//
// IsNegInfOp
//===----------------------------------------------------------------------===//
LogicalResult IsNegInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===//
// IsPosInfOp
//===----------------------------------------------------------------------===//
LogicalResult IsPosInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops. // Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//