diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 0986fa4..fd241cb 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -224,11 +224,14 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, - BASE_HLO_NotOp; + BASE_HLO_NotOp { +} def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, - BASE_HLO_NegOp; + BASE_HLO_NegOp { + let hasFolder = 1; +} def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 14ac3da..6e032d7 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1821,6 +1821,50 @@ static LogicalResult Verify(CaseOp op) { return success(); } +//===----------------------------------------------------------------------===// +// UnaryOps +//===----------------------------------------------------------------------===// + +template +static Attribute UnaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0]) return {}; + + DenseElementsAttr val = attrs[0].dyn_cast(); + if (!val) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; + } + + Type etype = type.getElementType(); + + // Evaluate for integer values. + if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(val.getNumElements()); + for (const auto v : val.getValues()) { + values.push_back(Convert()(v)); + } + + return DenseElementsAttr::get(type, values); +} + +#define UNARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder>(this, attrs); \ + return {}; \ + } + +UNARY_FOLDER(NegOp, std::negate); + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index b1dbec4..5f45bab 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -834,3 +834,20 @@ func @fold_xor_zeros_constants() -> tensor<4xi32> { // CHECK: return %0 return %2 : tensor<4xi32> } + +// CHECK-LABEL: func @fold_negate_int +func @fold_negate_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[0, -1, -6, 3]> + %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32> + return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_float +func @fold_negate_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]> + %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} +