Add a canonicalization pattern to remove redundant dynamic_reshapes.
PiperOrigin-RevId: 344517381
This commit is contained in:
		
							parent
							
								
									f183e95e69
								
							
						
					
					
						commit
						d14c63da54
					
				|  | @ -1268,8 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic | ||||||
| 
 | 
 | ||||||
| void DynamicReshapeOp::getCanonicalizationPatterns( | void DynamicReshapeOp::getCanonicalizationPatterns( | ||||||
|     OwningRewritePatternList& results, MLIRContext* context) { |     OwningRewritePatternList& results, MLIRContext* context) { | ||||||
|   results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>( |   results.insert<DynamicReshapeOpNotActuallyDynamic, | ||||||
|       context); |                  RemoveRedundantDynamicReshape, ShapeOfDynamicReshape>(context); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
|  |  | ||||||
|  | @ -28,3 +28,8 @@ def UnaryEinsumToEinsum : Pat< | ||||||
|   (HLO_UnaryEinsumOp $operand, $equation), |   (HLO_UnaryEinsumOp $operand, $equation), | ||||||
|   (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), |   (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), | ||||||
|                 $operand, (UnaryToBinaryEinsumEq $equation))>; |                 $operand, (UnaryToBinaryEinsumEq $equation))>; | ||||||
|  | 
 | ||||||
|  | // A dynamic reshape of a dynamic reshape is a dynamic reshape. | ||||||
|  | def RemoveRedundantDynamicReshape : Pat< | ||||||
|  |   (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2), | ||||||
|  |   (HLO_DynamicReshapeOp $operand, $shape2)>; | ||||||
|  |  | ||||||
|  | @ -585,6 +585,20 @@ func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> | ||||||
|   return %1 : tensor<2xindex> |   return %1 : tensor<2xindex> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape | ||||||
|  | // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] | ||||||
|  | // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] | ||||||
|  | func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor<?xf16>, %shape: tensor<?xindex>) -> tensor<?xf16> { | ||||||
|  |   // CHECK: [[RES:%[a-zA-Z0-9]+]] = "mhlo.dynamic_reshape"([[ARG0]], %{{[a-zA-Z0-9]+}}) : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16> | ||||||
|  |   // CHECK: return [[RES]] | ||||||
|  |   %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf16>, tensor<?xindex>) -> tensor<*xf16> | ||||||
|  |   %1 = shape.shape_of %0 : tensor<*xf16> -> tensor<?xindex> | ||||||
|  |   %2 = shape.num_elements %1 : tensor<?xindex> -> index | ||||||
|  |   %3 = tensor_from_elements %2 : tensor<1xindex> | ||||||
|  |   %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor<*xf16>, tensor<1xindex>) -> tensor<?xf16> | ||||||
|  |   return %4 : tensor<?xf16> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: do_not_dce_while_with_outfeed | // CHECK-LABEL: do_not_dce_while_with_outfeed | ||||||
| func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> { | func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> { | ||||||
|   // CHECK: mhlo.while |   // CHECK: mhlo.while | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue