From 09f804681688003677ee31529dbbbbb14086bcf9 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Thu, 11 Mar 2021 05:28:51 -0800 Subject: [PATCH] [MLIR:HLO:LINALG] Fix codegen for mhlo.reshape when one side is rank 0 This is an annoying edge case because the collapse->expand lowering expects at least R1 or it will produce invalid linalg reshapes. Using the direct lowering works fine. PiperOrigin-RevId: 362269199 --- .../mhlo/transforms/legalize_to_linalg.cc | 4 +++- tests/hlo-legalize-to-linalg.mlir | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 69f72bf..11f2fee 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -792,7 +792,9 @@ class ReshapeOpConverter : public OpConversionPattern { } curr_dst_dim++; } - if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()) + // Rank 0 can always use the direct lowering. + if (!src_shape.empty() && !dst_shape.empty() && + (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())) is_collapsing_source = false; // Otherwise, we need to first reduce all source dimensions into one and diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index af42311..c8259d0 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -473,6 +473,24 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // ----- +// CHECK-LABEL: func @reshape_0D_1D +func @reshape_0D_1D(%arg0: tensor) -> tensor<1xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> + return %0 : tensor<1xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor into tensor<1xi32> + +// ----- + +// CHECK-LABEL: func @reshape_1D_0D +func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor { + %0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor + return %0 : tensor +} +// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor + +// ----- + // CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D