Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))
This specific pattern can be replaced with the shape passed to dynamic_reshape. This is implemented as a canonicalization on mhlo.dynamic_reshape to fit in the infrastructure of canonicalization. PiperOrigin-RevId: 342009365
This commit is contained in:
		
							parent
							
								
									7fc4985eae
								
							
						
					
					
						commit
						1dffa62fe9
					
				| 
						 | 
					@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void DynamicReshapeOp::getCanonicalizationPatterns(
 | 
					void DynamicReshapeOp::getCanonicalizationPatterns(
 | 
				
			||||||
    OwningRewritePatternList& results, MLIRContext* context) {
 | 
					    OwningRewritePatternList& results, MLIRContext* context) {
 | 
				
			||||||
  results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
 | 
					  results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
 | 
				
			||||||
 | 
					      context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
 | 
				
			||||||
  (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
 | 
					  (HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
 | 
				
			||||||
  (replaceWithValue $x)>;
 | 
					  (replaceWithValue $x)>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ShapeOfDynamicReshape : Pat<
 | 
				
			||||||
 | 
					  (Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
 | 
				
			||||||
 | 
					  (replaceWithValue $shape)>;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
 | 
				
			||||||
  return %0 : tensor<4x1xf32>
 | 
					  return %0 : tensor<4x1xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @shape_of_dynamic_reshape
 | 
				
			||||||
 | 
					// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
 | 
				
			||||||
 | 
					// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
 | 
				
			||||||
 | 
					func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
 | 
				
			||||||
 | 
					  // CHECK: return [[ARG1]]
 | 
				
			||||||
 | 
					  %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 | 
				
			||||||
 | 
					  %1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
 | 
				
			||||||
 | 
					  return %1 : tensor<2xindex>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: do_not_dce_while_with_outfeed
 | 
					// CHECK-LABEL: do_not_dce_while_with_outfeed
 | 
				
			||||||
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
 | 
					func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
 | 
				
			||||||
  // CHECK: mhlo.while
 | 
					  // CHECK: mhlo.while
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue