Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))
This specific pattern can be replaced with the shape passed to dynamic_reshape. This is implemented as a canonicalization on mhlo.dynamic_reshape to fit in the infrastructure of canonicalization. PiperOrigin-RevId: 342009365
This commit is contained in:
		
							parent
							
								
									7fc4985eae
								
							
						
					
					
						commit
						1dffa62fe9
					
				|  | @ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic | |||
| 
 | ||||
| void DynamicReshapeOp::getCanonicalizationPatterns( | ||||
|     OwningRewritePatternList& results, MLIRContext* context) { | ||||
|   results.insert<DynamicReshapeOpNotActuallyDynamic>(context); | ||||
|   results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>( | ||||
|       context); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat< | |||
|   (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr), | ||||
|   (replaceWithValue $x)>; | ||||
| 
 | ||||
| def ShapeOfDynamicReshape : Pat< | ||||
|   (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)), | ||||
|   (replaceWithValue $shape)>; | ||||
|  |  | |||
|  | @ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor< | |||
|   return %0 : tensor<4x1xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @shape_of_dynamic_reshape | ||||
| // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] | ||||
| // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] | ||||
| func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> { | ||||
|   // CHECK: return [[ARG1]] | ||||
|   %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|   %1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex> | ||||
|   return %1 : tensor<2xindex> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: do_not_dce_while_with_outfeed | ||||
| func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> { | ||||
|   // CHECK: mhlo.while | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue