Implement constant folding for mhlo.Sign.
PiperOrigin-RevId: 373550014
This commit is contained in:
parent
d764806c1e
commit
d2cc74317c
|
@ -396,6 +396,7 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
|
||||||
See
|
See
|
||||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> {
|
||||||
|
|
|
@ -2498,6 +2498,29 @@ struct logical_not {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename FloatOrInt>
|
||||||
|
struct sign {
|
||||||
|
APFloat compute(const APFloat& f) {
|
||||||
|
if (f.isZero() || f.isNaN()) return f;
|
||||||
|
double value = f.isNegative() ? -1.0 : 1.0;
|
||||||
|
APFloat val(value);
|
||||||
|
bool unused;
|
||||||
|
val.convert(f.getSemantics(), APFloat::rmNearestTiesToEven, &unused);
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
APInt compute(const APInt& i) {
|
||||||
|
APInt r = i;
|
||||||
|
if (r == 0) return r;
|
||||||
|
if (r.isNegative()) {
|
||||||
|
return APInt(r.getBitWidth(), -1, /*isSigned=*/true);
|
||||||
|
}
|
||||||
|
return APInt(r.getBitWidth(), 1, /*isSigned=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
FloatOrInt operator()(const FloatOrInt& fi) { return compute(fi); }
|
||||||
|
};
|
||||||
|
|
||||||
#define UNARY_FOLDER(Op, Func) \
|
#define UNARY_FOLDER(Op, Func) \
|
||||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||||
|
@ -2522,6 +2545,7 @@ struct logical_not {
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY_FOLDER(NegOp, std::negate);
|
UNARY_FOLDER(NegOp, std::negate);
|
||||||
|
UNARY_FOLDER(SignOp, sign);
|
||||||
UNARY_FOLDER_INT(NotOp, logical_not);
|
UNARY_FOLDER_INT(NotOp, logical_not);
|
||||||
UNARY_FOLDER_FLOAT(RoundOp, round);
|
UNARY_FOLDER_FLOAT(RoundOp, round);
|
||||||
|
|
||||||
|
|
|
@ -679,6 +679,46 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
||||||
return %arg0 : tensor<i64>
|
return %arg0 : tensor<i64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_sign_posi
|
||||||
|
func @fold_sign_posi() -> tensor<i32> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<1> : tensor<i32>
|
||||||
|
%0 = mhlo.constant dense<2> : tensor<i32>
|
||||||
|
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %1 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_sign_negi
|
||||||
|
func @fold_sign_negi() -> tensor<i32> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<-1> : tensor<i32>
|
||||||
|
%0 = mhlo.constant dense<-2> : tensor<i32>
|
||||||
|
%1 = "mhlo.sign"(%0) : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %1 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_sign_posf
|
||||||
|
func @fold_sign_posf() -> tensor<bf16> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor<bf16>
|
||||||
|
%0 = mhlo.constant dense<2.000000e+00> : tensor<bf16>
|
||||||
|
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||||
|
return %1 : tensor<bf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_sign_negf
|
||||||
|
func @fold_sign_negf() -> tensor<bf16> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor<bf16>
|
||||||
|
%0 = mhlo.constant dense<-2.000000e+00> : tensor<bf16>
|
||||||
|
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||||
|
return %1 : tensor<bf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_sign_negzf
|
||||||
|
func @fold_sign_negzf() -> tensor<bf16> {
|
||||||
|
// CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
|
||||||
|
%0 = mhlo.constant dense<-0.000000e+00> : tensor<bf16>
|
||||||
|
%1 = "mhlo.sign"(%0) : (tensor<bf16>) -> tensor<bf16>
|
||||||
|
return %1 : tensor<bf16>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fold_compare_same_eq
|
// CHECK-LABEL: fold_compare_same_eq
|
||||||
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
|
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
|
||||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
|
Loading…
Reference in New Issue