diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index aeffaea..0a8105e 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -627,7 +627,8 @@ class ReshapeOpConverter : public OpConversionPattern { } currDstDim++; } - if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false; + if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) + isExpandingOrCollapsing = false; if (!isExpandingOrCollapsing) { auto getIdentityExprs = [&rewriter](int n) { diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index ea5bd83..263ea1b 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -395,6 +395,28 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { // ----- +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> + return %0 : tensor<1x4x1x512xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> + return %0 : tensor<4x1024x1x1xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %0 = "mhlo.minimum"(%lhs, %rhs) diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 136bee8..3162f37 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -714,6 +714,32 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) { // ----- +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape1_4D_4D +func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>, + %arg1: memref<1x4x1x512xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape2_4D_4D +func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>, + %arg1: memref<4x1024x1x1xi32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] +// CHECK: linalg.reshape %{{.*}} [#[[MAP]]] + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse