diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 29ef9da..b6c4340 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -396,6 +396,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; + let hasFolder = 1; } def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 7c4ff44..d6855d6 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2498,6 +2498,29 @@ struct logical_not { } }; +template +struct sign { + APFloat compute(const APFloat& f) { + if (f.isZero() || f.isNaN()) return f; + double value = f.isNegative() ? -1.0 : 1.0; + APFloat val(value); + bool unused; + val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused); + return val; + } + + APInt compute(const APInt& i) { + APInt r = i; + if (r == 0) return r; + if (r.isNegative()) { + return APInt(r.getBitWidth(), -1, /*isSigned=*/true); + } + return APInt(r.getBitWidth(), 1, /*isSigned=*/true); + } + + FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); } +}; + #define UNARY_FOLDER(Op, Func) \ OpFoldResult Op::fold(ArrayRef attrs) { \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -2522,6 +2545,7 @@ struct logical_not { } UNARY_FOLDER(NegOp, std::negate); +UNARY_FOLDER(SignOp, sign); UNARY_FOLDER_INT(NotOp, logical_not); UNARY_FOLDER_FLOAT(RoundOp, round); diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 1186faf..a3e12e6 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -679,6 +679,46 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } +// CHECK-LABEL: fold_sign_posi +func @fold_sign_posi() -> tensor { + // CHECK: %0 = mhlo.constant dense<1> : tensor + %0 = mhlo.constant dense<2> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negi +func @fold_sign_negi() -> tensor { + // CHECK: %0 = mhlo.constant dense<-1> : tensor + %0 = mhlo.constant dense<-2> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: fold_sign_posf +func @fold_sign_posf() -> tensor { + // CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negf +func @fold_sign_negf() -> tensor { + // CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor + %0 = mhlo.constant dense<-2.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negzf +func @fold_sign_negzf() -> tensor { + // CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor + %0 = mhlo.constant dense<-0.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + return %1 : tensor +} + // CHECK-LABEL: fold_compare_same_eq func @fold_compare_same_eq(%arg0: tensor) -> tensor { // CHECK: %0 = mhlo.constant dense : tensor