[MLIR][HLO] Implement type inference for `is_finite` op

PiperOrigin-RevId: 354261420
This commit is contained in:
A. Unique TensorFlower 2021-01-28 00:55:03 -08:00 committed by TensorFlow MLIR Team
parent c653db73c5
commit fe2e5a175f
3 changed files with 36 additions and 2 deletions

View File

@ -81,6 +81,9 @@ LogicalResult deriveShapeFromFirstOperand(
OpBuilder *builder, Operation *op,
SmallVectorImpl<Value> *reifiedReturnShapes);
// Type derivation function that returns a tensor type with a new element type.
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type);
} // end namespace mhlo
} // end namespace mlir

View File

@ -196,8 +196,8 @@ def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
let hasFolder = 1;
}
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite",
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_Tensor>,
BASE_HLO_IsFiniteOp {
let arguments = (ins HLO_FpTensor:$x);
let results = (outs HLO_PredTensor:$y);

View File

@ -976,6 +976,10 @@ OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
return {};
}
//===----------------------------------------------------------------------===//
// ImagOp
//===----------------------------------------------------------------------===//
namespace {
Type CreateRealType(Type type) {
auto element_ty = getElementTypeOrSelf(type);
@ -1009,6 +1013,33 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
return {};
}
//===----------------------------------------------------------------------===//
// IsFiniteOp
//===----------------------------------------------------------------------===//
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
}
if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
return UnrankedTensorType::get(element_type);
}
llvm_unreachable("unhandled type");
}
LogicalResult IsFiniteOp::inferReturnTypes(
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
auto arg_ty = operands.front().getType().cast<TensorType>();
Builder b(ctx);
inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
return success();
}
//===----------------------------------------------------------------------===//
// RealOp
//===----------------------------------------------------------------------===//
LogicalResult RealOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {