Fix a bug in the case check of reshape op lowering.

PiperOrigin-RevId: 332044191
This commit is contained in:
Hanhan Wang 2020-09-16 11:05:26 -07:00 committed by TensorFlow MLIR Team
parent f442ea84e2
commit b29dd5ef8f
3 changed files with 50 additions and 1 deletions

View File

@ -627,7 +627,8 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
currDstDim++; currDstDim++;
} }
if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false; if (currSrcDim != srcShape.size() || currDstDim != dstShape.size())
isExpandingOrCollapsing = false;
if (!isExpandingOrCollapsing) { if (!isExpandingOrCollapsing) {
auto getIdentityExprs = [&rewriter](int n) { auto getIdentityExprs = [&rewriter](int n) {

View File

@ -395,6 +395,28 @@ func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
// ----- // -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape1_4D_4D
func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32>
return %0 : tensor<1x4x1x512xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
// -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape2_4D_4D
func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32>
return %0 : tensor<4x1024x1x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP]]]
// -----
// CHECK-LABEL: func @minf // CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "mhlo.minimum"(%lhs, %rhs) %0 = "mhlo.minimum"(%lhs, %rhs)

View File

@ -714,6 +714,32 @@ func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
// ----- // -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape1_4D_4D
func @reshape1_4D_4D(%arg0: memref<4x512x1x1xi32>,
%arg1: memref<1x4x1x512xi32>) {
"lmhlo.reshape"(%arg0, %arg1)
: (memref<4x512x1x1xi32>, memref<1x4x1x512xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
// -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @reshape2_4D_4D
func @reshape2_4D_4D(%arg0: memref<4x1x1x1024xi32>,
%arg1: memref<4x1024x1x1xi32>) {
"lmhlo.reshape"(%arg0, %arg1)
: (memref<4x1x1x1024xi32>, memref<4x1024x1x1xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP]]]
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse // CHECK-LABEL: func @reverse