From 7f86dd9f7e28d90562633b500fd578a328b1d87e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 May 2021 18:53:02 -0700 Subject: [PATCH] Constant fold compare EQ if one of the operands is true and compare NE if one of the operands is false. PiperOrigin-RevId: 373058030 --- lib/Dialect/mhlo/IR/hlo_ops.cc | 33 +++++++++++++++++++++++++++++++++ tests/canonicalize.mlir | 16 ++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 547f45f..7c4ff44 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -3237,6 +3237,39 @@ OpFoldResult CompareOp::fold(ArrayRef operands) { return DenseIntElementsAttr::get(result_ty, {false}); } + auto op_el_type = lhs().getType().cast().getElementType(); + // Fold tensor<*xi1> != false to just return tensor<*xi1> + if (direction == "NE" && op_el_type.isInteger(1)) { + DenseIntElementsAttr cst_attr; + if (matchPattern(lhs(), m_Constant(&cst_attr))) { + if (cst_attr.isSplat() && !cst_attr.getSplatValue()) { + return rhs(); + } + } + + if (matchPattern(rhs(), m_Constant(&cst_attr))) { + if (cst_attr.isSplat() && !cst_attr.getSplatValue()) { + return lhs(); + } + } + } + + // Fold tensor<*xi1> == True to just return tensor<*xi1> + if (direction == "EQ" && op_el_type.isInteger(1)) { + DenseIntElementsAttr cst_attr; + if (matchPattern(lhs(), m_Constant(&cst_attr))) { + if (cst_attr.isSplat() && cst_attr.getSplatValue()) { + return rhs(); + } + } + + if (matchPattern(rhs(), m_Constant(&cst_attr))) { + if (cst_attr.isSplat() && cst_attr.getSplatValue()) { + return lhs(); + } + } + } + if (!operands[0] || !operands[1]) { return {}; } diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 85816b0..1186faf 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -745,6 +745,14 @@ func @fold_compare_true_eq() -> tensor { return %2 : tensor } +// CHECK-LABEL: fold_compare_bools_true_eq +func @fold_compare_bools_true_eq(%arg : tensor) -> tensor { + %1 = mhlo.constant dense : tensor + // CHECK: return %arg + %2 = "mhlo.compare"(%arg, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + // CHECK-LABEL: fold_compare_false_eq_float func @fold_compare_false_eq_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor @@ -781,6 +789,14 @@ func @fold_compare_true_ne() -> tensor { return %2 : tensor } +// CHECK-LABEL: fold_compare_bools_false_ne +func @fold_compare_bools_false_ne(%arg : tensor) -> tensor { + %1 = mhlo.constant dense : tensor + // CHECK: return %arg + %2 = "mhlo.compare"(%arg, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + // CHECK-LABEL: fold_compare_false_ne_float func @fold_compare_false_ne_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor