diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 2b1a18f..f1763c3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -81,6 +81,9 @@ LogicalResult deriveShapeFromFirstOperand( OpBuilder *builder, Operation *op, SmallVectorImpl *reifiedReturnShapes); +// Type derivation function that returns a tensor type with a new element type. +TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type); + } // end namespace mhlo } // end namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index aaa5c75..2f08797 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -196,8 +196,8 @@ def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", let hasFolder = 1; } -def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", - [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, +def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, + DeclareOpInterfaceMethods], HLO_Tensor>, BASE_HLO_IsFiniteOp { let arguments = (ins HLO_FpTensor:$x); let results = (outs HLO_PredTensor:$y); diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index b037593..47b0765 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -976,6 +976,10 @@ OpFoldResult ComplexOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// ImagOp +//===----------------------------------------------------------------------===// + namespace { Type CreateRealType(Type type) { auto element_ty = getElementTypeOrSelf(type); @@ -1009,6 +1013,33 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// IsFiniteOp +//===----------------------------------------------------------------------===// + +TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type) { + if (auto ranked_tensor_ty = tensor_type.dyn_cast()) { + return RankedTensorType::get(ranked_tensor_ty.getShape(), element_type); + } + if (auto unranked_tensor_ty = tensor_type.dyn_cast()) { + return UnrankedTensorType::get(element_type); + } + llvm_unreachable("unhandled type"); +} + +LogicalResult IsFiniteOp::inferReturnTypes( + MLIRContext* ctx, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto arg_ty = operands.front().getType().cast(); + Builder b(ctx); + inferredReturnTypes.push_back(getSameShapeTensorType(arg_ty, b.getI1Type())); + return success(); +} + +//===----------------------------------------------------------------------===// +// RealOp +//===----------------------------------------------------------------------===// + LogicalResult RealOp::inferReturnTypes( MLIRContext*, Optional, ValueRange operands, DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) {