[MLIR][CHLO] Implement type inference for `is_inf`-like operations in CHLO
PiperOrigin-RevId: 354265834
This commit is contained in:
parent
fe2e5a175f
commit
c3ddcd6c7f
|
@ -524,7 +524,8 @@ def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
|
def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf",
|
||||||
|
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
|
||||||
HLO_PredTensor> {
|
HLO_PredTensor> {
|
||||||
let summary = "IsInf predicate";
|
let summary = "IsInf predicate";
|
||||||
|
|
||||||
|
@ -533,8 +534,9 @@ def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", [], HLO_FpTensor,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
|
def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf",
|
||||||
HLO_FpTensor, HLO_PredTensor> {
|
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
|
||||||
|
HLO_PredTensor> {
|
||||||
let summary = "IsNegInf predicate";
|
let summary = "IsNegInf predicate";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -542,8 +544,9 @@ def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", [],
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", [],
|
def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf",
|
||||||
HLO_FpTensor, HLO_PredTensor> {
|
[DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_FpTensor,
|
||||||
|
HLO_PredTensor> {
|
||||||
let summary = "IsPosInf predicate";
|
let summary = "IsPosInf predicate";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -227,12 +227,52 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||||
attributes, element_type,
|
attributes, element_type,
|
||||||
inferedReturnShapes);
|
inferedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
||||||
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||||
reifiedReturnShapes);
|
reifiedReturnShapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// IsInfOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static Type getIsInfLikeReturnType(Value operand) {
|
||||||
|
Builder b(operand.getContext());
|
||||||
|
return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
|
||||||
|
b.getI1Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult IsInfOp::inferReturnTypes(
|
||||||
|
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// IsNegInfOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult IsNegInfOp::inferReturnTypes(
|
||||||
|
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// IsPosInfOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult IsPosInfOp::inferReturnTypes(
|
||||||
|
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Macros for method definitions that are common to most broadcasting ops.
|
// Macros for method definitions that are common to most broadcasting ops.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue