diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index bb5c0c8..7ec6629 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -193,9 +193,30 @@ struct HloToLhloCustomCallOpConverter } }; +class HloToLhloReshapeUnrankedConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::ReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + mhlo::ReshapeOp::Adaptor adaptor(operands); + auto unranked_operand_type = + adaptor.operand().getType().dyn_cast(); + if (unranked_operand_type == nullptr) return failure(); + + auto result_type = op.getType().cast(); + rewriter.replaceOpWithNewOp( + op, adaptor.operand(), + MemRefType::get(result_type.getShape(), result_type.getElementType())); + return success(); + } +}; + // TODO(pifon): Consider inserting lhlo.copy as in // HloToLhloDynamicBroadcastInDimOpConverter. -struct HloToLhloDynamicReshapeConverter +class HloToLhloDynamicReshapeConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -609,7 +630,8 @@ void populateDynamicHLOToLHLOConversionPattern( OwningRewritePatternList* patterns, bool insert_copy) { patterns->insert( *converter, context, insert_copy); - patterns->insert(*converter, context); + patterns->insert(*converter, context); } void populateHLOToLHLOConversionPattern(MLIRContext* context, diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index ee6e7b3..79530d0 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -32,3 +32,13 @@ func @dynamic_reshape_to_unranked( // CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) // CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]]) // CHECK-SAME: : (memref, memref) -> memref<*xf32> + +// ----- + +// CHECK-LABEL: func @reshape_unranked +func @reshape_unranked(%operand: tensor<*xf32>) -> tensor { + %reshaped = "mhlo.reshape"(%operand) : (tensor<*xf32>) -> tensor + return %reshaped : tensor +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) +// CHECK-NEXT: memref_cast [[ARG]] : memref<*xf32> to memref