[KERNEL_GEN] Add a pattern to bufferize `mhlo.reshape(<unranked_tensor>)`.
PiperOrigin-RevId: 356720899
This commit is contained in:
		
							parent
							
								
									54c2a49866
								
							
						
					
					
						commit
						36e04d92c0
					
				|  | @ -193,9 +193,30 @@ struct HloToLhloCustomCallOpConverter | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| class HloToLhloReshapeUnrankedConverter | ||||
|     : public BaseOpConversion<mhlo::ReshapeOp> { | ||||
|  public: | ||||
|   using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion; | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::ReshapeOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     mhlo::ReshapeOp::Adaptor adaptor(operands); | ||||
|     auto unranked_operand_type = | ||||
|         adaptor.operand().getType().dyn_cast<UnrankedMemRefType>(); | ||||
|     if (unranked_operand_type == nullptr) return failure(); | ||||
| 
 | ||||
|     auto result_type = op.getType().cast<RankedTensorType>(); | ||||
|     rewriter.replaceOpWithNewOp<MemRefCastOp>( | ||||
|         op, adaptor.operand(), | ||||
|         MemRefType::get(result_type.getShape(), result_type.getElementType())); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // TODO(pifon): Consider inserting lhlo.copy as in
 | ||||
| // HloToLhloDynamicBroadcastInDimOpConverter.
 | ||||
| struct HloToLhloDynamicReshapeConverter | ||||
| class HloToLhloDynamicReshapeConverter | ||||
|     : public BaseOpConversion<mhlo::DynamicReshapeOp> { | ||||
|  public: | ||||
|   using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion; | ||||
|  | @ -609,7 +630,8 @@ void populateDynamicHLOToLHLOConversionPattern( | |||
|     OwningRewritePatternList* patterns, bool insert_copy) { | ||||
|   patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>( | ||||
|       *converter, context, insert_copy); | ||||
|   patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context); | ||||
|   patterns->insert<HloToLhloDynamicReshapeConverter, | ||||
|                    HloToLhloReshapeUnrankedConverter>(*converter, context); | ||||
| } | ||||
| 
 | ||||
| void populateHLOToLHLOConversionPattern(MLIRContext* context, | ||||
|  |  | |||
|  | @ -32,3 +32,13 @@ func @dynamic_reshape_to_unranked( | |||
| // CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>) | ||||
| // CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) | ||||
| // CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @reshape_unranked | ||||
| func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> { | ||||
|   %reshaped = "mhlo.reshape"(%operand) : (tensor<*xf32>) -> tensor<f32> | ||||
|   return %reshaped : tensor<f32> | ||||
| } | ||||
| // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) | ||||
| // CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue