diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 7fec26a..2af7c44 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op]; - let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } }]; diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 8751ab4..3838d3d 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; +struct HloToLhloDynamicReshapeConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::DynamicReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Type result_type; + if (auto ranked_type = op.getType().dyn_cast()) { + result_type = + MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); + } else if (auto unranked_type = + op.getType().dyn_cast()) { + result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); + } else { + return failure(); + } + mhlo::DynamicReshapeOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, result_type, adaptor.operand(), adaptor.output_shape()); + return success(); + } +}; + struct HloToLhloReduceOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern( // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, + HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir index 063716a..cc60217 100644 --- a/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> // CHECK-NEXT: return [[ARG]] : memref<*xf32> + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_from_unranked +func @dynamic_reshape_from_unranked( + %operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor<*xf32>, tensor<1xi32>) -> tensor + return %reshaped : tensor +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_to_unranked +func @dynamic_reshape_to_unranked( + %operand: tensor, %shape: tensor) -> tensor<*xf32> { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor, tensor) -> tensor<*xf32> + return %reshaped : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref, memref) -> memref<*xf32>