Constant fold compare EQ if one of the operands is true and compare NE if one of the operands is false.
PiperOrigin-RevId: 373058030
This commit is contained in:
parent
2d88f2f601
commit
7f86dd9f7e
|
@ -3237,6 +3237,39 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return DenseIntElementsAttr::get(result_ty, {false});
|
return DenseIntElementsAttr::get(result_ty, {false});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto op_el_type = lhs().getType().cast<ShapedType>().getElementType();
|
||||||
|
// Fold tensor<*xi1> != false to just return tensor<*xi1>
|
||||||
|
if (direction == "NE" && op_el_type.isInteger(1)) {
|
||||||
|
DenseIntElementsAttr cst_attr;
|
||||||
|
if (matchPattern(lhs(), m_Constant(&cst_attr))) {
|
||||||
|
if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
|
||||||
|
return rhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchPattern(rhs(), m_Constant(&cst_attr))) {
|
||||||
|
if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) {
|
||||||
|
return lhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fold tensor<*xi1> == True to just return tensor<*xi1>
|
||||||
|
if (direction == "EQ" && op_el_type.isInteger(1)) {
|
||||||
|
DenseIntElementsAttr cst_attr;
|
||||||
|
if (matchPattern(lhs(), m_Constant(&cst_attr))) {
|
||||||
|
if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
|
||||||
|
return rhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchPattern(rhs(), m_Constant(&cst_attr))) {
|
||||||
|
if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) {
|
||||||
|
return lhs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!operands[0] || !operands[1]) {
|
if (!operands[0] || !operands[1]) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
|
@ -745,6 +745,14 @@ func @fold_compare_true_eq() -> tensor<i1> {
|
||||||
return %2 : tensor<i1>
|
return %2 : tensor<i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_bools_true_eq
|
||||||
|
func @fold_compare_bools_true_eq(%arg : tensor<i1>) -> tensor<i1> {
|
||||||
|
%1 = mhlo.constant dense<true> : tensor<i1>
|
||||||
|
// CHECK: return %arg
|
||||||
|
%2 = "mhlo.compare"(%arg, %1) {comparison_direction = "EQ"} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fold_compare_false_eq_float
|
// CHECK-LABEL: fold_compare_false_eq_float
|
||||||
func @fold_compare_false_eq_float() -> tensor<i1> {
|
func @fold_compare_false_eq_float() -> tensor<i1> {
|
||||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||||
|
@ -781,6 +789,14 @@ func @fold_compare_true_ne() -> tensor<i1> {
|
||||||
return %2 : tensor<i1>
|
return %2 : tensor<i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fold_compare_bools_false_ne
|
||||||
|
func @fold_compare_bools_false_ne(%arg : tensor<i1>) -> tensor<i1> {
|
||||||
|
%1 = mhlo.constant dense<false> : tensor<i1>
|
||||||
|
// CHECK: return %arg
|
||||||
|
%2 = "mhlo.compare"(%arg, %1) {comparison_direction = "NE"} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||||
|
return %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fold_compare_false_ne_float
|
// CHECK-LABEL: fold_compare_false_ne_float
|
||||||
func @fold_compare_false_ne_float() -> tensor<i1> {
|
func @fold_compare_false_ne_float() -> tensor<i1> {
|
||||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||||
|
|
Loading…
Reference in New Issue