Handle rank 1 broadcasts in unranked kernel lowering.

Previously, rank specialization started at rank 2 (after the checks for scalars and equal shapes), so cases such as <1xf32> + <2xf32> were treated as impossible.

PiperOrigin-RevId: 341043965
Author: Tres Popp, 2020-11-06 07:19:37 -08:00 (committed by TensorFlow MLIR Team)
Commit: af4c9774dc, parent: c30ea47682
2 changed files with 101 additions and 84 deletions
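
To make the fixed case concrete, here is a minimal sketch of the kind of input this lowering handles (the function and value names are invented for illustration). Both operands are unranked at compile time; at runtime they may turn out to be <1xf32> and <2xf32>, so the greatest operand rank is 1 and the new rank-1 specialization applies:

// Unranked broadcasting add; runtime shapes might be <1xf32> and <2xf32>.
func @add_unranked(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
  %sum = chlo.broadcast_add %lhs, %rhs : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %sum : tensor<*xf32>
}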


@@ -386,14 +386,14 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
         rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
     // Generate a list of nested if/else statements to handle rank
-    // specializations from 2-6.
+    // specializations from 1-6.
     scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
-                                                          rhs, greater_rank, 2);
+                                                          rhs, greater_rank, 1);
     // Put each subsequent rank specialization inside the else statement of the
     // previous one.
     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    for (int i = 3; i < max_rank_specialization; i++) {
+    for (int i = 2; i < max_rank_specialization; i++) {
      auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
                                                          rhs, greater_rank, i);
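
Read as the IR it emits rather than the C++ that builds it, the loop above now produces a chain of nested scf.if ops whose outermost branch tests for rank 1. A rough skeleton follows (invented names, placeholder branch bodies, using only the std/scf ops that also appear in the test below):

func @rank_specialization_skeleton(%arg: tensor<*xf32>, %greatest_rank: index) -> tensor<*xf32> {
  %c1 = constant 1 : index
  %is_rank_1 = cmpi "eq", %greatest_rank, %c1 : index
  %result = scf.if %is_rank_1 -> (tensor<*xf32>) {
    // Rank-1 specialization would go here (see the test below).
    scf.yield %arg : tensor<*xf32>
  } else {
    %c2 = constant 2 : index
    %is_rank_2 = cmpi "eq", %greatest_rank, %c2 : index
    %inner = scf.if %is_rank_2 -> (tensor<*xf32>) {
      // Rank-2 specialization would go here.
      scf.yield %arg : tensor<*xf32>
    } else {
      // Ranks 3..6 continue the same nesting pattern.
      scf.yield %arg : tensor<*xf32>
    }
    scf.yield %inner : tensor<*xf32>
  }
  return %result : tensor<*xf32>
}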


@@ -199,6 +199,21 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: } else {
 // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// Handle rank 1 specialization
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C1]] : index
+// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor_cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor_cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESULT_1:.*]] = tensor_cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
+// CHECK-NEXT: } else {
 // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
 // CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
 // Handle rank 2 specialization
@@ -292,5 +307,7 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: }
 // CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: return %[[VAL_71:.*]] : tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
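
For readers less used to FileCheck syntax, the rank-1 branch verified above corresponds roughly to the following self-contained sketch. The function wrapper, value names, and the explicit result type on shape.const_shape are assumptions added here for illustration; in the full lowering the operand shapes are computed earlier from the operands rather than passed in as arguments:

func @rank1_specialization_sketch(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>,
                                  %lhs_shape: tensor<?xindex>, %rhs_shape: tensor<?xindex>)
    -> tensor<*xf32> {
  // Broadcast each operand shape with [1] so a rank-0 operand is promoted to rank 1.
  %shape_1 = shape.const_shape [1] : tensor<1xindex>
  %lhs_bcast = shape.broadcast %lhs_shape, %shape_1 : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
  %lhs_shape_1 = tensor_cast %lhs_bcast : tensor<?xindex> to tensor<1xindex>
  %rhs_bcast = shape.broadcast %rhs_shape, %shape_1 : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
  %rhs_shape_1 = tensor_cast %rhs_bcast : tensor<?xindex> to tensor<1xindex>
  // Reshape both unranked operands to rank 1 and emit a ranked broadcasting add.
  %lhs_1 = "mhlo.dynamic_reshape"(%lhs, %lhs_shape_1) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  %rhs_1 = "mhlo.dynamic_reshape"(%rhs, %rhs_shape_1) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  %sum = chlo.broadcast_add %lhs_1, %rhs_1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // Erase the static rank again so every specialization yields tensor<*xf32>.
  %result = tensor_cast %sum : tensor<?xf32> to tensor<*xf32>
  return %result : tensor<*xf32>
}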