Constant fold compare EQ if one of the operands is true and compare NE if one of the operands is false.
PiperOrigin-RevId: 373058030
This commit is contained in:
		
							parent
							
								
									2d88f2f601
								
							
						
					
					
						commit
						7f86dd9f7e
					
				|  | @ -3237,6 +3237,39 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) { | ||||||
|     return DenseIntElementsAttr::get(result_ty, {false}); |     return DenseIntElementsAttr::get(result_ty, {false}); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   auto op_el_type = lhs().getType().cast<ShapedType>().getElementType(); | ||||||
|  |   // Fold tensor<*xi1> != false to just return tensor<*xi1>
 | ||||||
|  |   if (direction == "NE" && op_el_type.isInteger(1)) { | ||||||
|  |     DenseIntElementsAttr cst_attr; | ||||||
|  |     if (matchPattern(lhs(), m_Constant(&cst_attr))) { | ||||||
|  |       if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) { | ||||||
|  |         return rhs(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (matchPattern(rhs(), m_Constant(&cst_attr))) { | ||||||
|  |       if (cst_attr.isSplat() && !cst_attr.getSplatValue<bool>()) { | ||||||
|  |         return lhs(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Fold tensor<*xi1> == True to just return tensor<*xi1>
 | ||||||
|  |   if (direction == "EQ" && op_el_type.isInteger(1)) { | ||||||
|  |     DenseIntElementsAttr cst_attr; | ||||||
|  |     if (matchPattern(lhs(), m_Constant(&cst_attr))) { | ||||||
|  |       if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) { | ||||||
|  |         return rhs(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (matchPattern(rhs(), m_Constant(&cst_attr))) { | ||||||
|  |       if (cst_attr.isSplat() && cst_attr.getSplatValue<bool>()) { | ||||||
|  |         return lhs(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   if (!operands[0] || !operands[1]) { |   if (!operands[0] || !operands[1]) { | ||||||
|     return {}; |     return {}; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -745,6 +745,14 @@ func @fold_compare_true_eq() -> tensor<i1> { | ||||||
|   return %2 : tensor<i1> |   return %2 : tensor<i1> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: fold_compare_bools_true_eq | ||||||
|  | func @fold_compare_bools_true_eq(%arg : tensor<i1>) -> tensor<i1> { | ||||||
|  |   %1 = mhlo.constant dense<true> : tensor<i1> | ||||||
|  |   // CHECK: return %arg | ||||||
|  |   %2 = "mhlo.compare"(%arg, %1) {comparison_direction = "EQ"} : (tensor<i1>, tensor<i1>) -> tensor<i1> | ||||||
|  |   return %2 : tensor<i1> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: fold_compare_false_eq_float | // CHECK-LABEL: fold_compare_false_eq_float | ||||||
| func @fold_compare_false_eq_float() -> tensor<i1> { | func @fold_compare_false_eq_float() -> tensor<i1> { | ||||||
|   %0 = mhlo.constant dense<0.> : tensor<f32> |   %0 = mhlo.constant dense<0.> : tensor<f32> | ||||||
|  | @ -781,6 +789,14 @@ func @fold_compare_true_ne() -> tensor<i1> { | ||||||
|   return %2 : tensor<i1> |   return %2 : tensor<i1> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: fold_compare_bools_false_ne | ||||||
|  | func @fold_compare_bools_false_ne(%arg : tensor<i1>) -> tensor<i1> { | ||||||
|  |   %1 = mhlo.constant dense<false> : tensor<i1> | ||||||
|  |   // CHECK: return %arg | ||||||
|  |   %2 = "mhlo.compare"(%arg, %1) {comparison_direction = "NE"} : (tensor<i1>, tensor<i1>) -> tensor<i1> | ||||||
|  |   return %2 : tensor<i1> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: fold_compare_false_ne_float | // CHECK-LABEL: fold_compare_false_ne_float | ||||||
| func @fold_compare_false_ne_float() -> tensor<i1> { | func @fold_compare_false_ne_float() -> tensor<i1> { | ||||||
|   %0 = mhlo.constant dense<1.> : tensor<f32> |   %0 = mhlo.constant dense<1.> : tensor<f32> | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue