diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index deb7654..1486078 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -689,8 +689,9 @@ class ReshapeOpConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { if (!VerifyHloOpBufferOrTensorSemantics(reshape_op)) return failure(); + typename OpTy::Adaptor operands(args); ShapedType operand_type = - reshape_op.operand().getType().template cast(); + operands.operand().getType().template cast(); ShapedType result_type = GetHloOpResultType(reshape_op); if (!operand_type.hasStaticShape() || !result_type.hasStaticShape()) @@ -708,7 +709,11 @@ class ReshapeOpConverter : public OpConversionPattern { unsigned curr_src_dim = 0, curr_dst_dim = 0; SmallVector reassociation_map( dst_shape.size()); - bool is_expanding_or_collapsing = true; + + // First scan all dimensions in the source shapes to see whether we have a + // perfect case where consecutive dimensions in source are collapsed. For + // such case we can just generate one single linalg.reshape. + bool is_collapsing_source = true; while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) { int64_t dst_size = dst_shape[curr_dst_dim]; int64_t src_size = src_shape[curr_src_dim]; @@ -731,15 +736,17 @@ class ReshapeOpConverter : public OpConversionPattern { } } } else { - is_expanding_or_collapsing = false; + is_collapsing_source = false; break; } curr_dst_dim++; } if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()) - is_expanding_or_collapsing = false; + is_collapsing_source = false; - if (!is_expanding_or_collapsing) { + // Otherwise, we need to first reduce all source dimensions into one and + // then expand to the destination dimensions. + if (!is_collapsing_source) { auto get_identity_exprs = [&rewriter](int n) { SmallVector exprs; for (int i = 0; i < n; ++i) @@ -751,9 +758,13 @@ class ReshapeOpConverter : public OpConversionPattern { 1, std::multiplies()); auto elem_type = operand_type.getElementType(); SmallVector collapsing_map = { - get_identity_exprs(dst_shape.size())}; + // Use operand_type here because we need to collapse all operands + // dimensions. + get_identity_exprs(operand_type.getShape().size())}; SmallVector expanding_map = { - get_identity_exprs(src_shape.size())}; + // Use result_type here because we need to expand to all result + // dimensions. + get_identity_exprs(result_type.getShape().size())}; if (isLHLO) { auto collapsed_type = MemRefType::get({total_elems}, elem_type); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 48332d2..1fb52f5 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -508,6 +508,18 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @reshape_4D_3D +func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> + return %0 : tensor<1x240x1xf32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]] + +// ----- + // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @reshape1_4D_4D func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 36bbdb7..ec614be 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -811,6 +811,20 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @reshape_4D_3D +func @reshape_4D_3D(%arg0: memref<1x8x10x3xf32>, %arg1: memref<1x240x1xf32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<1x8x10x3xf32>, memref<1x240x1xf32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]] +// CHECK: linalg.copy + +// ----- + // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @reshape1_4D_4D func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,