[MLIR:HLO:LINALG] Fix codegen for mhlo.reshape when one side is rank 0
This is an annoying edge case because the collapse->expand lowering expects at least R1 or it will produce invalid linalg reshapes. Using the direct lowering works fine. PiperOrigin-RevId: 362269199
This commit is contained in:
parent
d77b556822
commit
09f8046816
|
@ -792,7 +792,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||||
}
|
}
|
||||||
curr_dst_dim++;
|
curr_dst_dim++;
|
||||||
}
|
}
|
||||||
if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
|
// Rank 0 can always use the direct lowering.
|
||||||
|
if (!src_shape.empty() && !dst_shape.empty() &&
|
||||||
|
(curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()))
|
||||||
is_collapsing_source = false;
|
is_collapsing_source = false;
|
||||||
|
|
||||||
// Otherwise, we need to first reduce all source dimensions into one and
|
// Otherwise, we need to first reduce all source dimensions into one and
|
||||||
|
|
|
@ -473,6 +473,24 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @reshape_0D_1D
|
||||||
|
func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
|
||||||
|
%0 = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
|
||||||
|
return %0 : tensor<1xi32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<i32> into tensor<1xi32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @reshape_1D_0D
|
||||||
|
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
|
||||||
|
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||||
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
||||||
// CHECK-LABEL: func @reshape_3D_2D
|
// CHECK-LABEL: func @reshape_3D_2D
|
||||||
|
|
Loading…
Reference in New Issue