Add special cases for SelectOp rank specialization.

We now use the same special cases for all ops with arity >= 2.
For binary ops, we now have only one special case if at least one of the
operands has exactly one element. In that case, we reshape both operands to
rank 1. Before, we had separate special cases whether the left-hand side
or the right-hand side have a scalar shape.

PiperOrigin-RevId: 366005835
This commit is contained in:
Adrian Kuegel 2021-03-31 04:28:00 -07:00 committed by TensorFlow MLIR Team
parent 9206805c58
commit 4033a56750
2 changed files with 287 additions and 245 deletions

View File

@ -183,7 +183,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
Value size_tensor =
rewriter.create<tensor::FromElementsOp>(loc, num_elements);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, scalar_element_type),
loc,
RankedTensorType::get({RankedTensorType::kDynamicSize},
scalar_element_type),
lhs_is_scalar ? rhs : lhs, size_tensor);
// Create a new ranked Chlo op that will be further lowered by other
@ -191,7 +193,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
rhs_is_scalar ? rhs : reshaped};
Value computed = rewriter.create<ChloOpTy>(
loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
loc,
TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize},
result_element_type)},
new_operands, op->getAttrs());
// Reshape the result back into an unranked tensor.
@ -202,7 +206,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
}
};
template <typename ChloOpTy, typename HloOpTy>
template <typename ChloOpTy>
struct ConvertUnrankedDynamicBroadcastOpHelper {
// Returns the dynamic result of checking the given value is effectively a
// scalar shape (i.e. the number of elements is 1).
@ -378,14 +382,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
// %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
//
// The sequence of specializations this handles is:
// - Either operand being scalar
// - Operands having equal shapes
// - The resulting value being any of ranks [2,6]
// - At most one operand has a shape that does not consist of exactly one
// element.
// - All operands having equal shapes
// - The resulting minimized shapes being any of ranks [1,5]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
struct ConvertUnrankedDynamicBroadcastNaryOp
: public OpConversionPattern<ChloOpTy> {
using OpConversionPattern<ChloOpTy>::OpConversionPattern;
@ -394,76 +399,96 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
typename ChloOpTy::Adaptor transformed(operands);
Value lhs = transformed.lhs();
Value rhs = transformed.rhs();
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
ValueRange transformed_operands = transformed.getOperands();
auto num_operands = transformed_operands.size();
llvm::SmallVector<UnrankedTensorType, 3> operand_types;
operand_types.reserve(num_operands);
for (int i = 0; i < num_operands; ++i) {
auto type =
transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
// Only support unranked operands.
if (!type) return failure();
operand_types.push_back(type);
}
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Only support unranked operands. If either operand is ranked, another
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
llvm::SmallVector<Value> shapes;
shapes.reserve(num_operands);
for (int i = 0; i < num_operands; ++i) {
shapes.push_back(
rewriter.create<shape::ShapeOfOp>(loc, transformed_operands[i]));
}
Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
// If at most one shape does not have exactly one element
Value counter = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
for (int i = 0; i < num_operands; ++i) {
Value is_scalar_like = IsSingleElementShape(rewriter, op, shapes[i]);
Value counter_plus_one = rewriter.create<AddIOp>(loc, counter, one);
counter = rewriter.create<SelectOp>(loc, is_scalar_like, counter_plus_one,
counter);
}
Value num_operands_minus_one =
rewriter.create<ConstantIndexOp>(loc, num_operands - 1);
Value at_most_one_non_scalar =
rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::uge,
counter, num_operands_minus_one);
// If lhs has exactly one element
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
true);
OpBuilder if_lhs_scalar_builder =
auto if_op = rewriter.create<scf::IfOp>(loc, result_type,
at_most_one_non_scalar, true);
OpBuilder if_at_most_one_non_scalar_builder =
if_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
llvm::SmallVector<Value, 3> reshaped_operands;
reshaped_operands.reserve(num_operands);
for (int i = 0; i < num_operands; ++i) {
Value num_elements =
if_at_most_one_non_scalar_builder.create<shape::NumElementsOp>(
loc, shapes[i]);
Value size_tensor =
if_at_most_one_non_scalar_builder.create<tensor::FromElementsOp>(
loc, num_elements);
Value reshaped =
if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
loc,
RankedTensorType::get({RankedTensorType::kDynamicSize},
operand_types[i].getElementType()),
transformed_operands[i], size_tensor);
reshaped_operands.push_back(reshaped);
}
auto rank_one_result_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, result_type.getElementType());
Value if_at_most_one_non_scalar_result =
if_at_most_one_non_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{rank_one_result_type}, reshaped_operands,
op->getAttrs());
Value extended_if_lhs_scalar_result =
extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
shape_of_lhs, shape_of_rhs);
if_lhs_scalar_builder.create<scf::YieldOp>(loc,
extended_if_lhs_scalar_result);
Value extended_result = extendToBroadcastShape(
if_at_most_one_non_scalar_builder, loc, result_type,
if_at_most_one_non_scalar_result, shapes);
if_at_most_one_non_scalar_builder.create<scf::YieldOp>(loc,
extended_result);
// If lhs does not have exactly one element
// If there is more than one shape which does not have exactly one element
//
// See if rhs has exactly one element
OpBuilder else_lhs_scalar_builder =
if_op.getElseBodyBuilder(rewriter.getListener());
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type,
IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder =
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op->getAttrs());
Value extended_if_rhs_scalar_result =
extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
shape_of_lhs, shape_of_rhs);
if_rhs_scalar_builder.create<scf::YieldOp>(loc,
extended_if_rhs_scalar_result);
// See if all shapes are equal.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
Value equal_shapes =
else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[1]);
for (int i = 2; i < num_operands; ++i) {
Value are_equal =
else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[i]);
equal_shapes = else_builder.create<AndOp>(loc, equal_shapes, are_equal);
}
// If NEITHER shape has exactly one element
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder =
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
loc, result_type, equal_shapes, true);
else_no_scalars_builder.create<scf::YieldOp>(loc,
if_eq_shapes_op.getResult(0));
auto if_eq_shapes_op =
else_builder.create<scf::IfOp>(loc, result_type, equal_shapes, true);
else_builder.create<scf::YieldOp>(loc, if_eq_shapes_op.getResult(0));
OpBuilder if_eq_shapes_builder =
if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
Value non_broadcast_op = Adaptor::CreateOp(
op, result_type, transformed_operands, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes do not have exactly one element, nor are equal
@ -472,9 +497,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
OpBuilder if_neq_shapes_builder =
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
if_neq_shapes_builder.create<scf::YieldOp>(
loc, ConvertUnrankedDynamicBroadcastOpHelper<
ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
op, {lhs, rhs}));
loc,
ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
if_neq_shapes_builder, op, transformed_operands));
rewriter.replaceOp(op, {if_op.getResult(0)});
return success();
@ -494,37 +519,18 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
rewriter.create<ConstantIndexOp>(loc, 1));
}
Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
Value shape_of_lhs, Value shape_of_rhs) const {
Value extendToBroadcastShape(OpBuilder &builder, Location loc,
Type result_type, Value value,
ValueRange shapes) const {
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, builder.getIndexType());
Value broadcast_shape =
builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
shape_of_lhs, shape_of_rhs, nullptr);
return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
Value broadcast_shape = builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, shapes, nullptr);
return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
broadcast_shape);
}
};
// Rank-specialize chlo.broadcast_select ops.
struct ConvertUnrankedDynamicBroadcastSelectOp
: public OpConversionPattern<chlo::BroadcastSelectOp> {
using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// For now only do the bare minimum and specialize for every rank. There is
// more potential for optimization here. This also is missing the
// specialization for rank 0.
rewriter.replaceOp(
op, {ConvertUnrankedDynamicBroadcastOpHelper<
chlo::BroadcastSelectOp,
mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
return success();
}
};
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@ -588,11 +594,14 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
#undef MAP_HLO
#undef MAP_CHLO
#undef COMMA
chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
context, patterns);
patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
chlo::BroadcastSelectOp, mhlo::SelectOp,
chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp,
mhlo::SelectOp>>>(context);
chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
patterns->insert<ConvertUnrankedDynamicBroadcastSelectOp>(context);
}
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {

View File

@ -159,33 +159,27 @@ func @addUnrankedUnranked(
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK-NEXT: %[[COUNTER_PLUS_ONE2:.*]] = addi %[[COUNTER]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE2]], %[[COUNTER]] : index
// Handle scalar case
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER2]], %[[C1]] : index
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
@ -290,8 +284,6 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32>
@ -315,6 +307,48 @@ func @selectUnrankedUnrankedUnranked(
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %c0, %c1 : index
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %c0 : index
// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %c1 : index
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index
// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %c1 : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %c1 : index
// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index
// CHECK-NEXT: %c2 = constant 2 : index
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %c2 : index
// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1
// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
@ -327,7 +361,6 @@ func @selectUnrankedUnrankedUnranked(
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {