[MLIR][CHLO] Implement type inference for `is_inf`-like operations in CHLO

PiperOrigin-RevId: 354265834
This commit is contained in:
A. Unique TensorFlower 2021-01-28 01:35:42 -08:00 committed by TensorFlow MLIR Team
parent fe2e5a175f
commit c3ddcd6c7f
2 changed files with 48 additions and 5 deletions

View File

@ -524,7 +524,8 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
}];
}
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf",
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsInf predicate";
@ -533,8 +534,9 @@ def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
}];
}
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
HLO_FpTensor, HLO_PredTensor> {
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf",
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsNegInf predicate";
let description = [{
@ -542,8 +544,9 @@ def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
}];
}
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [],
HLO_FpTensor, HLO_PredTensor> {
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf",
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
HLO_PredTensor> {
let summary = "IsPosInf predicate";
let description = [{

View File

@ -227,12 +227,52 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
attributes, element_type,
inferedReturnShapes);
}
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//
static Type getIsInfLikeReturnType(Value operand) {
Builder b(operand.getContext());
return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
b.getI1Type());
}
LogicalResult IsInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===//
// IsNegInfOp
//===----------------------------------------------------------------------===//
LogicalResult IsNegInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===//
// IsPosInfOp
//===----------------------------------------------------------------------===//
LogicalResult IsPosInfOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
return success();
}
//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//