Implement constant folding for mhlo.Sign.
PiperOrigin-RevId: 373550014
This commit is contained in:
parent
d764806c1e
commit
d2cc74317c
|
@ -396,6 +396,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
|
|||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||
|
|
|
@ -2498,6 +2498,29 @@ struct logical_not {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename FloatOrInt>
|
||||
struct sign {
|
||||
APFloat compute(const APFloat& f) {
|
||||
if (f.isZero() || f.isNaN()) return f;
|
||||
double value = f.isNegative() ? -1.0 : 1.0;
|
||||
APFloat val(value);
|
||||
bool unused;
|
||||
val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
|
||||
return val;
|
||||
}
|
||||
|
||||
APInt compute(const APInt& i) {
|
||||
APInt r = i;
|
||||
if (r == 0) return r;
|
||||
if (r.isNegative()) {
|
||||
return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
|
||||
}
|
||||
return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
|
||||
}
|
||||
|
||||
FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
|
||||
};
|
||||
|
||||
#define UNARY_FOLDER(Op, Func) \
|
||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||
|
@ -2522,6 +2545,7 @@ struct logical_not {
|
|||
}
|
||||
|
||||
UNARY_FOLDER(NegOp, std::negate);
|
||||
UNARY_FOLDER(SignOp, sign);
|
||||
UNARY_FOLDER_INT(NotOp, logical_not);
|
||||
UNARY_FOLDER_FLOAT(RoundOp, round);
|
||||
|
||||
|
|
|
@ -679,6 +679,46 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
|||
return %arg0 : tensor<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_sign_posi
|
||||
func @fold_sign_posi() -> tensor<i32> {
|
||||
// CHECK: %0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%0 = mhlo.constant dense<2> : tensor<i32>
|
||||
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_sign_negi
|
||||
func @fold_sign_negi() -> tensor<i32> {
|
||||
// CHECK: %0 = mhlo.constant dense<-1> : tensor<i32>
|
||||
%0 = mhlo.constant dense<-2> : tensor<i32>
|
||||
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_sign_posf
|
||||
func @fold_sign_posf() -> tensor<bf16> {
|
||||
// CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor<bf16>
|
||||
%0 = mhlo.constant dense<2.000000e+00> : tensor<bf16>
|
||||
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||
return %1 : tensor<bf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_sign_negf
|
||||
func @fold_sign_negf() -> tensor<bf16> {
|
||||
// CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor<bf16>
|
||||
%0 = mhlo.constant dense<-2.000000e+00> : tensor<bf16>
|
||||
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||
return %1 : tensor<bf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_sign_negzf
|
||||
func @fold_sign_negzf() -> tensor<bf16> {
|
||||
// CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
|
||||
%0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
|
||||
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||
return %1 : tensor<bf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_eq
|
||||
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
|
|
Loading…
Reference in New Issue