From 8d051723c09c109c8241bf651de984d13d6956fa Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Thu, 17 Dec 2020 17:36:05 -0800 Subject: [PATCH] Use InferTypeOpInterface for HLO AbsOp and fix result shape inference Shape inference in case of ops with complex element types need to use the element type of complex as the result element type and not the full operand type. Before: "mhlo.abs"(%arg0) : (tensor<4xcomplex>) -> tensor<4xtensor<4xcomplex>> After: "mhlo.abs"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> PiperOrigin-RevId: 348123967 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 5 ++--- lib/Dialect/mhlo/IR/hlo_ops.cc | 25 ++++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 63fec24..cbdc3b6 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -146,10 +146,9 @@ class HLO_UnaryElementwiseOp traits, // Abs supports complex to real, so element type is not guaranteed to match. def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", - [NoSideEffect, SameOperandsAndResultShape], + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { - let builders = [ - OpBuilderDAG<(ins "Value":$operand)>]; } def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 3a091a9..a03bfa1 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -454,18 +454,23 @@ static LogicalResult Verify(DynamicUpdateSliceOp op) { // AbsOp //===----------------------------------------------------------------------===// -void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) { - auto shaped_type = operand.getType().cast(); - Type new_type; - if (!shaped_type.getElementType().isa()) { - new_type = operand.getType(); - } else if (shaped_type.hasRank()) { - new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType()); - } else { - new_type = UnrankedTensorType::get(operand.getType()); +LogicalResult AbsOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto operand_ty = (*operands.begin()).getType().cast(); + Type element_ty = operand_ty.getElementType(); + if (auto complex_ty = element_ty.dyn_cast()) { + element_ty = complex_ty.getElementType(); } - return AbsOp::build(builder, result, new_type, operand); + Type result_ty; + if (operand_ty.hasRank()) { + result_ty = RankedTensorType::get(operand_ty.getShape(), element_ty); + } else { + result_ty = UnrankedTensorType::get(element_ty); + } + inferredReturnTypes.push_back(result_ty); + return success(); } //===----------------------------------------------------------------------===//