Remove rank 1 specialization from TransformUnrankedHloPass.

For binary ops, we already special-case rank 0 vs rank 1, and same shape. So we
don't need to special-case a maximum rank of 1.

PiperOrigin-RevId: 360881387
This commit is contained in:
Adrian Kuegel 2021-03-04 04:02:50 -08:00 committed by TensorFlow MLIR Team
parent 50a516fb9c
commit 62b357b601
2 changed files with 77 additions and 87 deletions

View File

@@ -289,7 +289,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
// Iterates over the desired ranks to be specialized and generates the code // Iterates over the desired ranks to be specialized and generates the code
// snippet for each case. // snippet for each case.
static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
ValueRange operands) { ValueRange operands,
int min_rank_specialization) {
auto loc = op.getLoc(); auto loc = op.getLoc();
// Get the minimum broadcast shapes of the operands. // Get the minimum broadcast shapes of the operands.
@@ -336,18 +337,20 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
} }
// Generate a list of nested if/else statements to handle rank // Generate a list of nested if/else statements to handle rank
// specializations from 1 to `kMaxRankSpecialization`. // specializations from `min_rank_specialization` to
// `kMaxRankSpecialization`.
constexpr int kMaxRankSpecialization = 5;
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
rewriter, op, greater_rank, 1); rewriter, op, greater_rank, min_rank_specialization);
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
reduced_shapes, 1); reduced_shapes,
min_rank_specialization);
// Put each subsequent rank specialization inside the else statement of the // Put each subsequent rank specialization inside the else statement of the
// previous one. // previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
constexpr int kMaxRankSpecialization = 5; for (int i = min_rank_specialization + 1; i < kMaxRankSpecialization; i++) {
for (int i = 2; i < kMaxRankSpecialization; i++) {
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
else_builder, op, greater_rank, i); else_builder, op, greater_rank, i);
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
@@ -468,13 +471,15 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
// If shapes do not have exactly one element, nor are equal // If shapes do not have exactly one element, nor are equal
// //
// See if values are of a rank that we support. // See if values are of a rank that we support. We already handle cases
// where one of the operands is a scalar, or both are equal. So there should
// be no case left where both have rank 1.
OpBuilder if_neq_shapes_builder = OpBuilder if_neq_shapes_builder =
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener()); if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
if_neq_shapes_builder.create<scf::YieldOp>( if_neq_shapes_builder.create<scf::YieldOp>(
loc, ConvertUnrankedDynamicBroadcastOpHelper< loc, ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy, HloOpTy>::
ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder, HandleBroadcastAndOp(if_neq_shapes_builder, op, {lhs, rhs},
op, {lhs, rhs})); /*min_rank_specialization=*/2));
rewriter.replaceOp(op, {if_op.getResult(0)}); rewriter.replaceOp(op, {if_op.getResult(0)});
return success(); return success();
@@ -518,9 +523,10 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
// more potential for optimization here. This also is missing the // more potential for optimization here. This also is missing the
// specialization for rank 0. // specialization for rank 0.
rewriter.replaceOp( rewriter.replaceOp(
op, {ConvertUnrankedDynamicBroadcastOpHelper< op, {ConvertUnrankedDynamicBroadcastOpHelper<chlo::BroadcastSelectOp,
chlo::BroadcastSelectOp, mhlo::SelectOp>::
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)}); HandleBroadcastAndOp(rewriter, op, operands,
/*min_rank_specialization=*/1)});
return success(); return success();
} }
}; };

View File

@@ -207,24 +207,10 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index // CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index // CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization // Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { // CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] // CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> // CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> // CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
@@ -285,8 +271,6 @@ func @addUnrankedUnranked(
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32> // CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } // CHECK-NEXT: }