Remove rank 1 specialization from TransformUnrankedHloPass.
For binary ops, we already special-case rank 0 vs. rank 1 operands, as well as equal shapes, so there is no need for a separate specialization when the maximum rank is 1. PiperOrigin-RevId: 360881387
This commit is contained in:
parent
50a516fb9c
commit
62b357b601
|
@ -289,7 +289,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||||
// Iterates over the desired ranks to be specialized and generates the code
|
// Iterates over the desired ranks to be specialized and generates the code
|
||||||
// snippet for each case.
|
// snippet for each case.
|
||||||
static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
|
static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
|
||||||
ValueRange operands) {
|
ValueRange operands,
|
||||||
|
int min_rank_specialization) {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
// Get the minimum broadcast shapes of the operands.
|
// Get the minimum broadcast shapes of the operands.
|
||||||
|
@ -336,18 +337,20 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a list of nested if/else statements to handle rank
|
// Generate a list of nested if/else statements to handle rank
|
||||||
// specializations from 1 to `kMaxRankSpecialization`.
|
// specializations from `min_rank_specialization` to
|
||||||
|
// `kMaxRankSpecialization`.
|
||||||
|
constexpr int kMaxRankSpecialization = 5;
|
||||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
rewriter, op, greater_rank, 1);
|
rewriter, op, greater_rank, min_rank_specialization);
|
||||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||||
reduced_shapes, 1);
|
reduced_shapes,
|
||||||
|
min_rank_specialization);
|
||||||
|
|
||||||
// Put each subsequent rank specialization inside the else statement of the
|
// Put each subsequent rank specialization inside the else statement of the
|
||||||
// previous one.
|
// previous one.
|
||||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
constexpr int kMaxRankSpecialization = 5;
|
for (int i = min_rank_specialization + 1; i < kMaxRankSpecialization; i++) {
|
||||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
|
||||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
else_builder, op, greater_rank, i);
|
else_builder, op, greater_rank, i);
|
||||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||||
|
@ -468,13 +471,15 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
||||||
|
|
||||||
// If shapes do not have exactly one element, nor are equal
|
// If shapes do not have exactly one element, nor are equal
|
||||||
//
|
//
|
||||||
// See if values are of a rank that we support.
|
// See if values are of a rank that we support. We already handle cases
|
||||||
|
// where one of the operands is a scalar, or both are equal. So there should
|
||||||
|
// be no case left where both have rank 1.
|
||||||
OpBuilder if_neq_shapes_builder =
|
OpBuilder if_neq_shapes_builder =
|
||||||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||||
loc, ConvertUnrankedDynamicBroadcastOpHelper<
|
loc, ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy, HloOpTy>::
|
||||||
ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
|
HandleBroadcastAndOp(if_neq_shapes_builder, op, {lhs, rhs},
|
||||||
op, {lhs, rhs}));
|
/*min_rank_specialization=*/2));
|
||||||
|
|
||||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||||
return success();
|
return success();
|
||||||
|
@ -518,9 +523,10 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
|
||||||
// more potential for optimization here. This also is missing the
|
// more potential for optimization here. This also is missing the
|
||||||
// specialization for rank 0.
|
// specialization for rank 0.
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(
|
||||||
op, {ConvertUnrankedDynamicBroadcastOpHelper<
|
op, {ConvertUnrankedDynamicBroadcastOpHelper<chlo::BroadcastSelectOp,
|
||||||
chlo::BroadcastSelectOp,
|
mhlo::SelectOp>::
|
||||||
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
|
HandleBroadcastAndOp(rewriter, op, operands,
|
||||||
|
/*min_rank_specialization=*/1)});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -207,24 +207,10 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
|
||||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||||
// Handle rank 1 specialization
|
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
|
|
||||||
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
|
|
||||||
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: } else {
|
|
||||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
|
||||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
|
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
|
||||||
// Handle rank 2 specialization
|
// Handle rank 2 specialization
|
||||||
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
|
||||||
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
|
||||||
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||||
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
|
||||||
|
@ -285,8 +271,6 @@ func @addUnrankedUnranked(
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
|
|
||||||
// CHECK-NEXT: }
|
|
||||||
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||||
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
|
|
Loading…
Reference in New Issue