[KERNEL_GEN] Add a pattern to bufferize `mhlo.reshape(<unranked_tensor>)`.
PiperOrigin-RevId: 356720899
This commit is contained in:
		
							parent
							
								
									54c2a49866
								
							
						
					
					
						commit
						36e04d92c0
					
				| 
						 | 
					@ -193,9 +193,30 @@ struct HloToLhloCustomCallOpConverter
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HloToLhloReshapeUnrankedConverter
 | 
				
			||||||
 | 
					    : public BaseOpConversion<mhlo::ReshapeOp> {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(
 | 
				
			||||||
 | 
					      mhlo::ReshapeOp op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					      ConversionPatternRewriter& rewriter) const final {
 | 
				
			||||||
 | 
					    mhlo::ReshapeOp::Adaptor adaptor(operands);
 | 
				
			||||||
 | 
					    auto unranked_operand_type =
 | 
				
			||||||
 | 
					        adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
 | 
				
			||||||
 | 
					    if (unranked_operand_type == nullptr) return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto result_type = op.getType().cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<MemRefCastOp>(
 | 
				
			||||||
 | 
					        op, adaptor.operand(),
 | 
				
			||||||
 | 
					        MemRefType::get(result_type.getShape(), result_type.getElementType()));
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(pifon): Consider inserting lhlo.copy as in
 | 
					// TODO(pifon): Consider inserting lhlo.copy as in
 | 
				
			||||||
// HloToLhloDynamicBroadcastInDimOpConverter.
 | 
					// HloToLhloDynamicBroadcastInDimOpConverter.
 | 
				
			||||||
struct HloToLhloDynamicReshapeConverter
 | 
					class HloToLhloDynamicReshapeConverter
 | 
				
			||||||
    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 | 
					    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
 | 
					  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
 | 
				
			||||||
| 
						 | 
					@ -609,7 +630,8 @@ void populateDynamicHLOToLHLOConversionPattern(
 | 
				
			||||||
    OwningRewritePatternList* patterns, bool insert_copy) {
 | 
					    OwningRewritePatternList* patterns, bool insert_copy) {
 | 
				
			||||||
  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
 | 
					  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
 | 
				
			||||||
      *converter, context, insert_copy);
 | 
					      *converter, context, insert_copy);
 | 
				
			||||||
  patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
 | 
					  patterns->insert<HloToLhloDynamicReshapeConverter,
 | 
				
			||||||
 | 
					                   HloToLhloReshapeUnrankedConverter>(*converter, context);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
					void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,3 +32,13 @@ func @dynamic_reshape_to_unranked(
 | 
				
			||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
 | 
					// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
 | 
				
			||||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 | 
					// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 | 
				
			||||||
// CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
 | 
					// CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @reshape_unranked
 | 
				
			||||||
 | 
					func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  %reshaped = "mhlo.reshape"(%operand) : (tensor<*xf32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %reshaped : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
 | 
				
			||||||
 | 
					// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue