diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 126c7ad..52a5a49 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -217,6 +217,7 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, BASE_HLO_NotOp { + let hasFolder = 1; } def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 8eb0670..16b987d 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2195,6 +2195,12 @@ struct round { } }; +struct logical_not { + APInt operator()(const APInt& i) { + return APInt(i.getBitWidth(), static_cast(!i)); + } +}; + #define UNARY_FOLDER(Op, Func) \ OpFoldResult Op::fold(ArrayRef attrs) { \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -2204,6 +2210,13 @@ struct round { return {}; \ } +#define UNARY_FOLDER_INT(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return UnaryFolder(this, attrs); \ + return {}; \ + } + #define UNARY_FOLDER_FLOAT(Op, Func) \ OpFoldResult Op::fold(ArrayRef attrs) { \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -2212,8 +2225,13 @@ struct round { } UNARY_FOLDER(NegOp, std::negate); +UNARY_FOLDER_INT(NotOp, logical_not); UNARY_FOLDER_FLOAT(RoundOp, round); +#undef UNARY_FOLDER +#undef UNARY_FOLDER_INT +#undef UNARY_FOLDER_FLOAT + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 4206ac3..c234d32 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1237,6 +1237,22 @@ func @fold_negate_float() -> tensor<4xf32> { return %1 : tensor<4xf32> } +// CHECK-LABEL func @fold_not() +func @fold_not() -> tensor<2x2xi1> { + %0 = mhlo.constant dense<[[true, false], [true, false]]> : tensor<2x2xi1> + // CHECK{LITERAL}: mhlo.constant dense<[[false, true], [false, true]]> : tensor<2x2xi1> + %1 = "mhlo.not"(%0) : (tensor<2x2xi1>) -> tensor<2x2xi1> + return %1 : tensor<2x2xi1> +} + +// CHECK-LABEL func @fold_not_i32() +func @fold_not_i32() -> tensor<2x2xi32> { + %0 = mhlo.constant dense<[[42, -12], [1, 0]]> : tensor<2x2xi32> + // CHECK-LITERAL: mhlo.constant dense<[[0, 0], [0, 1]]> : tensor<2x2xi32> + %1 = "mhlo.not"(%0) : (tensor<2x2xi32>) -> tensor<2x2xi32> + return %1 : tensor<2x2xi32> +} + // CHECK-LABEL: func @fold_sqrt_f32_constants func @fold_sqrt_f32_constants() -> tensor<4xf32> { %0 = mhlo.constant dense<1.0> : tensor<4xf32>