[MLIR:HLO:LINALG] Fix codegen for mhlo.reshape when one side is rank 0

This is an annoying edge case because the collapse->expand lowering expects at
least R1 or it will produce invalid linalg reshapes. Using the direct lowering
works fine.

PiperOrigin-RevId: 362269199
This commit is contained in:
Benjamin Kramer 2021-03-11 05:28:51 -08:00 committed by TensorFlow MLIR Team
parent d77b556822
commit 09f8046816
2 changed files with 21 additions and 1 deletions

View File

@ -792,7 +792,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
curr_dst_dim++; curr_dst_dim++;
} }
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()) // Rank 0 can always use the direct lowering.
if (!src_shape.empty() && !dst_shape.empty() &&
(curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()))
is_collapsing_source = false; is_collapsing_source = false;
// Otherwise, we need to first reduce all source dimensions into one and // Otherwise, we need to first reduce all source dimensions into one and

View File

@ -473,6 +473,24 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// ----- // -----
// CHECK-LABEL: func @reshape_0D_1D
func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<i32> into tensor<1xi32>
// -----
// CHECK-LABEL: func @reshape_1D_0D
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D