Implement constant folding for mhlo.Sign.

PiperOrigin-RevId: 373550014
This commit is contained in:
A. Unique TensorFlower 2021-05-13 03:53:14 -07:00 committed by TensorFlow MLIR Team
parent d764806c1e
commit d2cc74317c
3 changed files with 65 additions and 0 deletions

View File

@ -396,6 +396,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
See See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}]; }];
let hasFolder = 1;
} }
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {

View File

@ -2498,6 +2498,29 @@ struct logical_not {
} }
}; };
template <typename FloatOrInt>
struct sign {
APFloat compute(const APFloat& f) {
if (f.isZero() || f.isNaN()) return f;
double value = f.isNegative() ? -1.0 : 1.0;
APFloat val(value);
bool unused;
val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
return val;
}
APInt compute(const APInt& i) {
APInt r = i;
if (r == 0) return r;
if (r.isNegative()) {
return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
}
return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
}
FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
};
#define UNARY_FOLDER(Op, Func) \ #define UNARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \ OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \ if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
@ -2522,6 +2545,7 @@ struct logical_not {
} }
UNARY_FOLDER(NegOp, std::negate); UNARY_FOLDER(NegOp, std::negate);
UNARY_FOLDER(SignOp, sign);
UNARY_FOLDER_INT(NotOp, logical_not); UNARY_FOLDER_INT(NotOp, logical_not);
UNARY_FOLDER_FLOAT(RoundOp, round); UNARY_FOLDER_FLOAT(RoundOp, round);

View File

@ -679,6 +679,46 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
return %arg0 : tensor<i64> return %arg0 : tensor<i64>
} }
// CHECK-LABEL: fold_sign_posi
func @fold_sign_posi() -> tensor<i32> {
// CHECK: %0 = mhlo.constant dense<1> : tensor<i32>
%0 = mhlo.constant dense<2> : tensor<i32>
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// CHECK-LABEL: fold_sign_negi
func @fold_sign_negi() -> tensor<i32> {
// CHECK: %0 = mhlo.constant dense<-1> : tensor<i32>
%0 = mhlo.constant dense<-2> : tensor<i32>
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// CHECK-LABEL: fold_sign_posf
func @fold_sign_posf() -> tensor<bf16> {
// CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor<bf16>
%0 = mhlo.constant dense<2.000000e+00> : tensor<bf16>
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
return %1 : tensor<bf16>
}
// CHECK-LABEL: fold_sign_negf
func @fold_sign_negf() -> tensor<bf16> {
// CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor<bf16>
%0 = mhlo.constant dense<-2.000000e+00> : tensor<bf16>
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
return %1 : tensor<bf16>
}
// CHECK-LABEL: fold_sign_negzf
func @fold_sign_negzf() -> tensor<bf16> {
// CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
%0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
return %1 : tensor<bf16>
}
// CHECK-LABEL: fold_compare_same_eq // CHECK-LABEL: fold_compare_same_eq
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> { func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1> // CHECK: %0 = mhlo.constant dense<true> : tensor<i1>