[KERNEL_GEN] Add a pattern to bufferize `mhlo.reshape(<unranked_tensor>)`.
PiperOrigin-RevId: 356720899
This commit is contained in:
parent
54c2a49866
commit
36e04d92c0
|
@ -193,9 +193,30 @@ struct HloToLhloCustomCallOpConverter
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloToLhloReshapeUnrankedConverter
|
||||||
|
: public BaseOpConversion<mhlo::ReshapeOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mhlo::ReshapeOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
mhlo::ReshapeOp::Adaptor adaptor(operands);
|
||||||
|
auto unranked_operand_type =
|
||||||
|
adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
|
||||||
|
if (unranked_operand_type == nullptr) return failure();
|
||||||
|
|
||||||
|
auto result_type = op.getType().cast<RankedTensorType>();
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(
|
||||||
|
op, adaptor.operand(),
|
||||||
|
MemRefType::get(result_type.getShape(), result_type.getElementType()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO(pifon): Consider inserting lhlo.copy as in
|
// TODO(pifon): Consider inserting lhlo.copy as in
|
||||||
// HloToLhloDynamicBroadcastInDimOpConverter.
|
// HloToLhloDynamicBroadcastInDimOpConverter.
|
||||||
struct HloToLhloDynamicReshapeConverter
|
class HloToLhloDynamicReshapeConverter
|
||||||
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||||
public:
|
public:
|
||||||
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||||
|
@ -609,7 +630,8 @@ void populateDynamicHLOToLHLOConversionPattern(
|
||||||
OwningRewritePatternList* patterns, bool insert_copy) {
|
OwningRewritePatternList* patterns, bool insert_copy) {
|
||||||
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
|
||||||
*converter, context, insert_copy);
|
*converter, context, insert_copy);
|
||||||
patterns->insert<HloToLhloDynamicReshapeConverter>(*converter, context);
|
patterns->insert<HloToLhloDynamicReshapeConverter,
|
||||||
|
HloToLhloReshapeUnrankedConverter>(*converter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
|
|
|
@ -32,3 +32,13 @@ func @dynamic_reshape_to_unranked(
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
|
||||||
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
|
||||||
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @reshape_unranked
|
||||||
|
func @reshape_unranked(%operand: tensor<*xf32>) -> tensor<f32> {
|
||||||
|
%reshaped = "mhlo.reshape"(%operand) : (tensor<*xf32>) -> tensor<f32>
|
||||||
|
return %reshaped : tensor<f32>
|
||||||
|
}
|
||||||
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
||||||
|
// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref<f32>
|
||||||
|
|
Loading…
Reference in New Issue