From 39589add2233c418cd0629b712ed61f85ea2513f Mon Sep 17 00:00:00 2001
From: Lei Zhang
Date: Thu, 28 Jan 2021 05:36:37 -0800
Subject: [PATCH] Use the correct shape when converting mhlo.reshape

If mhlo.reshape is not purely collapsing some consecutive operand
dimensions into result dimensions, we will generate two linalg reshape
ops for it: the first one collapses all operand dimensions into one
dimension, and the second one expands it to all result dimensions. In
this case, the number of collapsed/expanded dimensions should come
strictly from the operand/result.

This is different from the case where we can generate one linalg
reshape. For that case, the reassociation map should have rank equal
to the larger of the operand/result shapes.

PiperOrigin-RevId: 354293826
---
 .../mhlo/transforms/legalize_to_linalg.cc | 25 +++++++++++++------
 tests/hlo-legalize-to-linalg.mlir         | 12 +++++++++
 tests/lhlo-legalize-to-linalg.mlir        | 14 +++++++++++
 3 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index deb7654..1486078 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
       ConversionPatternRewriter& rewriter) const final {
     if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
       return failure();
+    typename OpTy::Adaptor operands(args);
     ShapedType operand_type =
-        reshape_op.operand().getType().template cast<ShapedType>();
+        operands.operand().getType().template cast<ShapedType>();
     ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);
     if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
       return failure();
@@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
     unsigned curr_src_dim = 0, curr_dst_dim = 0;
     SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
        dst_shape.size());
-    bool is_expanding_or_collapsing = true;
+
+    // First scan all dimensions in the source shape to see whether we have a
+    // perfect case where consecutive source dimensions are collapsed. In that
+    // case we can generate a single linalg.reshape.
+    bool is_collapsing_source = true;
     while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
       int64_t dst_size = dst_shape[curr_dst_dim];
       int64_t src_size = src_shape[curr_src_dim];
@@ -731,15 +736,17 @@
           }
         }
       } else {
-        is_expanding_or_collapsing = false;
+        is_collapsing_source = false;
         break;
       }
       curr_dst_dim++;
     }
     if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
-      is_expanding_or_collapsing = false;
+      is_collapsing_source = false;
 
-    if (!is_expanding_or_collapsing) {
+    // Otherwise, we need to first reduce all source dimensions into one and
+    // then expand to the destination dimensions.
+    if (!is_collapsing_source) {
       auto get_identity_exprs = [&rewriter](int n) {
         SmallVector<AffineExpr, 4> exprs;
         for (int i = 0; i < n; ++i)
@@ -751,9 +758,13 @@
                                             1, std::multiplies<int64_t>());
       auto elem_type = operand_type.getElementType();
       SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
-          get_identity_exprs(dst_shape.size())};
+          // Use operand_type here because we need to collapse all operand
+          // dimensions.
+          get_identity_exprs(operand_type.getShape().size())};
       SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
-          get_identity_exprs(src_shape.size())};
+          // Use result_type here because we need to expand to all result
+          // dimensions.
+          get_identity_exprs(result_type.getShape().size())};
 
       if (isLHLO) {
         auto collapsed_type = MemRefType::get({total_elems}, elem_type);
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 48332d2..1fb52f5 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @reshape_4D_3D
+func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32>
+  return %0 : tensor<1x240x1xf32>
+}
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+
+// -----
+
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @reshape1_4D_4D
 func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir
index 36bbdb7..ec614be 100644
--- a/tests/lhlo-legalize-to-linalg.mlir
+++ b/tests/lhlo-legalize-to-linalg.mlir
@@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @reshape_4D_3D
+func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) {
+  "lmhlo.reshape"(%arg0, %arg1)
+    : (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> ()
+  return
+}
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+// CHECK: linalg.copy
+
+// -----
+
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @reshape1_4D_4D
 func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
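--
Note: for the new reshape_4D_3D test, the two-reshape lowering this patch
produces amounts to the sketch below. This is a hand-written approximation
derived from the CHECK lines above and from the collapsed type computed via
total_elems (1 * 8 * 10 * 3 = 240); the exact linalg.tensor_reshape syntax
is assumed from the linalg dialect of this era, not taken from the patch
itself.

  func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> {
    // Collapse all four operand dimensions into a single dimension of 240
    // elements, using one reassociation map over (d0, d1, d2, d3).
    %0 = linalg.tensor_reshape %arg0
        [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
        : tensor<1x8x10x3xf32> into tensor<240xf32>
    // Expand that single dimension into all three result dimensions, using
    // one reassociation map over (d0, d1, d2).
    %1 = linalg.tensor_reshape %0
        [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
        : tensor<240xf32> into tensor<1x240x1xf32>
    return %1 : tensor<1x240x1xf32>
  }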