[MLIR][HLO] Implement type inference for `is_finite` op
PiperOrigin-RevId: 354261420
This commit is contained in:
parent
c653db73c5
commit
fe2e5a175f
|
@ -81,6 +81,9 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
OpBuilder *builder, Operation *op,
|
||||
SmallVectorImpl<Value> *reifiedReturnShapes);
|
||||
|
||||
// Type derivation function that returns a tensor type with a new element type.
|
||||
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type);
|
||||
|
||||
} // end namespace mhlo
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -196,8 +196,8 @@ def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite",
|
||||
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
|
||||
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_Tensor>,
|
||||
BASE_HLO_IsFiniteOp {
|
||||
let arguments = (ins HLO_FpTensor:$x);
|
||||
let results = (outs HLO_PredTensor:$y);
|
||||
|
|
|
@ -976,6 +976,10 @@ OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ImagOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
Type CreateRealType(Type type) {
|
||||
auto element_ty = getElementTypeOrSelf(type);
|
||||
|
@ -1009,6 +1013,33 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IsFiniteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
|
||||
if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
|
||||
return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
|
||||
}
|
||||
if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
|
||||
return UnrankedTensorType::get(element_type);
|
||||
}
|
||||
llvm_unreachable("unhandled type");
|
||||
}
|
||||
|
||||
LogicalResult IsFiniteOp::inferReturnTypes(
|
||||
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||
auto arg_ty = operands.front().getType().cast<TensorType>();
|
||||
Builder b(ctx);
|
||||
inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RealOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult RealOp::inferReturnTypes(
|
||||
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||
|
|
Loading…
Reference in New Issue