[MLIR][HLO] Implement type inference for `is_finite` op
PiperOrigin-RevId: 354261420
This commit is contained in:
parent
c653db73c5
commit
fe2e5a175f
|
@ -81,6 +81,9 @@ LogicalResult deriveShapeFromFirstOperand(
|
||||||
OpBuilder *builder, Operation *op,
|
OpBuilder *builder, Operation *op,
|
||||||
SmallVectorImpl<Value> *reifiedReturnShapes);
|
SmallVectorImpl<Value> *reifiedReturnShapes);
|
||||||
|
|
||||||
|
// Type derivation function that returns a tensor type with a new element type.
|
||||||
|
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type);
|
||||||
|
|
||||||
} // end namespace mhlo
|
} // end namespace mhlo
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -196,8 +196,8 @@ def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite",
|
def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect,
|
||||||
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
|
DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_Tensor>,
|
||||||
BASE_HLO_IsFiniteOp {
|
BASE_HLO_IsFiniteOp {
|
||||||
let arguments = (ins HLO_FpTensor:$x);
|
let arguments = (ins HLO_FpTensor:$x);
|
||||||
let results = (outs HLO_PredTensor:$y);
|
let results = (outs HLO_PredTensor:$y);
|
||||||
|
|
|
@ -976,6 +976,10 @@ OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ImagOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
Type CreateRealType(Type type) {
|
Type CreateRealType(Type type) {
|
||||||
auto element_ty = getElementTypeOrSelf(type);
|
auto element_ty = getElementTypeOrSelf(type);
|
||||||
|
@ -1009,6 +1013,33 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// IsFiniteOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) {
|
||||||
|
if (auto ranked_tensor_ty = tensor_type.dyn_cast<RankedTensorType>()) {
|
||||||
|
return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type);
|
||||||
|
}
|
||||||
|
if (auto unranked_tensor_ty = tensor_type.dyn_cast<UnrankedTensorType>()) {
|
||||||
|
return UnrankedTensorType::get(element_type);
|
||||||
|
}
|
||||||
|
llvm_unreachable("unhandled type");
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult IsFiniteOp::inferReturnTypes(
|
||||||
|
MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
auto arg_ty = operands.front().getType().cast<TensorType>();
|
||||||
|
Builder b(ctx);
|
||||||
|
inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// RealOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult RealOp::inferReturnTypes(
|
LogicalResult RealOp::inferReturnTypes(
|
||||||
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
|
Loading…
Reference in New Issue