Fix a bug in the case check of reshape op lowering.
PiperOrigin-RevId: 332044191
This commit is contained in:
parent
f442ea84e2
commit
b29dd5ef8f
|
@ -627,7 +627,8 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
|||
}
|
||||
currDstDim++;
|
||||
}
|
||||
if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
|
||||
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
|
||||
isExpandingOrCollapsing = false;
|
||||
|
||||
if (!isExpandingOrCollapsing) {
|
||||
auto getIdentityExprs = [&rewriter](int n) {
|
||||
|
|
|
@ -395,6 +395,28 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape1_4D_4D
|
||||
func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32>
|
||||
return %0 : tensor<1x4x1x512xi32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape2_4D_4D
|
||||
func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> {
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32>
|
||||
return %0 : tensor<4x1024x1x1xi32>
|
||||
}
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @minf
|
||||
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "mhlo.minimum"(%lhs, %rhs)
|
||||
|
|
|
@ -714,6 +714,32 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape1_4D_4D
|
||||
func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
|
||||
%arg1: memref<1x4x1x512xi32>) {
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
// CHECK-LABEL: func @reshape2_4D_4D
|
||||
func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>,
|
||||
%arg1: memref<4x1024x1x1xi32>) {
|
||||
"lmhlo.reshape"(%arg0, %arg1)
|
||||
: (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
|
|
Loading…
Reference in New Issue