[KERNEL_GEN] Add a pattern to bufferize `mhlo.reshape(<unranked_tensor>)`.

PiperOrigin-RevId: 356720899
This commit is contained in:
Alexander Belyaev 2021-02-10 06:30:50 -08:00 committed by TensorFlow MLIR Team
parent 54c2a49866
commit 36e04d92c0
2 changed files with 34 additions and 2 deletions

View File

@ -193,9 +193,30 @@ struct HloToLhloCustomCallOpConverter
}
};
class HloToLhloReshapeUnrankedConverter
: public BaseOpConversion<mhlo::ReshapeOp> {
public:
using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::ReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
mhlo::ReshapeOp::Adaptor adaptor(operands);
auto unranked_operand_type =
adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
if (unranked_operand_type == nullptr) return failure();
auto result_type = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<MemRefCastOp>(
op, adaptor.operand(),
MemRefType::get(result_type.getShape(), result_type.getElementType()));
return success();
}
};
// TODO(pifon): Consider inserting lhlo.copy as in
// HloToLhloDynamicBroadcastInDimOpConverter.
struct HloToLhloDynamicReshapeConverter
class HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
@ -609,7 +630,8 @@ void populateDynamicHLOToLHLOConversionPattern(
OwningRewritePatternList* patterns, bool insert_copy) {
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
*converter, context, insert_copy);
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
patterns->insert<HloToLhloDynamicReshapeConverter,
HloToLhloReshapeUnrankedConverter>(*converter, context);
}
void populateHLOToLHLOConversionPattern(MLIRContext* context,

View File

@ -32,3 +32,13 @@ func @dynamic_reshape_to_unranked(
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
// -----
// CHECK-LABEL: func @reshape_unranked
func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
%reshaped = "mhlo.reshape"(%operand) : (tensor<*xf32>) -> tensor<f32>
return %reshaped : tensor<f32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>