Fix bug using std.rank instead of shape.rank

PiperOrigin-RevId: 339890070
This commit is contained in:
Tres Popp 2020-10-30 09:58:48 -07:00 committed by TensorFlow MLIR Team
parent 76b30fd426
commit 81e8d778c4
2 changed files with 6 additions and 8 deletions

View File

@ -377,9 +377,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
Value rhs_shape = Value rhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs); rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
Value lhs_rank = Value lhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape); rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
Value rhs_rank = Value rhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape); rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
Value greater_rank_lhs = Value greater_rank_lhs =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank); rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
Value greater_rank = Value greater_rank =

View File

@ -158,9 +158,9 @@ func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, // CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { // CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index // CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index // CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[LHS_RANK]], %[[C0]] : index
// Handle scalar LHS case // Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32> // CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor<f32>
@ -173,8 +173,8 @@ func @addUnrankedUnranked(
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else { // CHECK-NEXT: } else {
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> // CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index // CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index // CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RHS_RANK]], %[[C0]] : index
// Handle scalar RHS case // Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32> // CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor<f32>
@ -197,8 +197,6 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else { // CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index // CHECK-NEXT: %[[C2:.*]] = constant 2 : index