diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 63fec24..cbdc3b6 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp traits, // Abs supports complex to real, so element type is not guaranteed to match. def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", - [NoSideEffect, SameOperandsAndResultShape], + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { - let builders = [ - OpBuilderDAG<(ins "Value":$operand)>]; } def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 3a091a9..a03bfa1 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) { // AbsOp //===----------------------------------------------------------------------===// -void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) { - auto shaped_type = operand.getType().cast(); - Type new_type; - if (!shaped_type.getElementType().isa()) { - new_type = operand.getType(); - } else if (shaped_type.hasRank()) { - new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType()); - } else { - new_type = UnrankedTensorType::get(operand.getType()); +LogicalResult AbsOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto operand_ty = (*operands.begin()).getType().cast(); + Type element_ty = operand_ty.getElementType(); + if (auto complex_ty = element_ty.dyn_cast()) { + element_ty = complex_ty.getElementType(); } - return AbsOp::build(builder, result, new_type, operand); + Type result_ty; + if (operand_ty.hasRank()) { + result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty); + } else { + result_ty = UnrankedTensorType::get(element_ty); + } + inferredReturnTypes.push_back(result_ty); + return success(); } //===----------------------------------------------------------------------===//