Support CHLO->LHLO lowering for broadcasting operations with both inputs unranked.

PiperOrigin-RevId: 323960733
This commit is contained in:
Tres Popp 2020-07-30 01:56:40 -07:00 committed by TensorFlow MLIR Team
parent b09bf2a4dc
commit ffef8d6593
2 changed files with 363 additions and 1 deletions

View File

@ -220,6 +220,229 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
}
};
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
//
// The sequence of specializations this handles is:
// - Either operand being scalar
// - Operands having equal shapes
// - The resulting value being any of ranks [2,6]
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Only support unranked operands. If either operand is ranked, another
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
// If lhs is scalar
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
// If lhs is NOT scalar
//
// See if rhs is scalar
OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
// If NEITHER shape is scalar
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs)
.getResult());
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs)
.getResult());
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
loc, result_type, equal_shapes, true);
else_no_scalars_builder.create<scf::YieldOp>(loc,
if_eq_shapes_op.getResult(0));
OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
Value non_broadcast_op =
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes are not scalar, nor equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
if_neq_shapes_builder.create<scf::YieldOp>(
loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
rewriter.replaceOp(op, {if_op.getResult(0)});
return success();
}
private:
// Returns the dyanamic result of checking the given value is a scalar
// tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
auto loc = op.getLoc();
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type,
rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult());
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
rank_tensor,
rewriter.create<ConstantIndexOp>(loc, 0));
}
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
Value lhs, Value rhs,
Value actual_rank,
int targeted_rank) const {
auto loc = op.getLoc();
// Create the if block to place the current specialized logic in.
Value greater_rank_is_n = builder.create<CmpIOp>(
loc, CmpIPredicate::eq, actual_rank,
builder.create<ConstantIndexOp>(loc, targeted_rank));
auto if_op =
builder.create<scf::IfOp>(loc, lhs.getType(), greater_rank_is_n, true);
OpBuilder if_builder = if_op.getThenBodyBuilder();
// Handle shape broadcasting and inferrence.
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto extent_tensor_type =
RankedTensorType::get({targeted_rank}, builder.getIndexType());
auto reshaped_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
lhs.getType().template dyn_cast<TensorType>().getElementType());
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
loc, extent_tensor_type,
mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape));
// TODO(tpopp): Return extent tensors when possible to signal that this is a
// guaranteed safe broadcast by construction.
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, lhs_shape, ranked_shape_val, nullptr);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, rhs_shape, ranked_shape_val, nullptr);
Value lhs_extent_tensor = if_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type, extended_lhs);
Value rhs_extent_tensor = if_builder.create<shape::ToExtentTensorOp>(
loc, extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
// can be broadcasted and do the actual broadcasting)
// 3. Type erase the output back to unranked
Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, lhs, lhs_extent_tensor);
Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
loc, reshaped_type, rhs, rhs_extent_tensor);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{reshaped_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
Value reshaped_result = if_builder.create<TensorCastOp>(
loc, UnrankedTensorType::get(reshaped_type.getElementType()), result);
if_builder.create<scf::YieldOp>(loc, reshaped_result);
// Return the if_op, so the result can be used and the else block can be
// used for the next rank specialized step.
return if_op;
}
// Iterates over the desired ranks to be specialized and generates the code
// snippet for each case.
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
Value rhs) const {
constexpr int max_rank_specialization = 7;
auto loc = op.getLoc();
// Find the larger rank of the 2 operands.
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value lhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
Value rhs_shape =
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
Value lhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), lhs_shape);
Value rhs_rank =
rewriter.create<RankOp>(loc, rewriter.getIndexType(), rhs_shape);
Value greater_rank_lhs =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
Value greater_rank =
rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
// Generate a list of nested if/else statements to handle rank
// specializations from 2-6.
scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs,
rhs, greater_rank, 2);
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder();
for (int i = 3; i < max_rank_specialization; i++) {
auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs,
rhs, greater_rank, i);
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
else_builder = inner_if.getElseBodyBuilder();
}
// Fire an assertion if none of the rank specializations applied (one of the
// ranks was greater than 6).
else_builder.create<AssertOp>(
loc, else_builder.create<ConstantIntOp>(loc, 0, 1),
"Input for dynamic binary op lowering was of a rank greater than 6");
else_builder.create<scf::YieldOp>(loc, lhs);
// Return the result of the outermost if statement.
return if_op.getResult(0);
}
};
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns) {
@ -230,7 +453,8 @@ void PopulateForBinaryOp(MLIRContext *context,
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 5);
patterns->insert<
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>,
ConvertUnrankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context);
}

View File

@ -311,3 +311,141 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
// -----
func @addUnrankedUnranked(
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32>
// CHECK: %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]]
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index
// Handle scalar LHS case
// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32>
// CHECK: %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]]
// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index
// Handle scalar RHS case
// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex>
// Handle scalar RHS case
// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32>
// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor<?xindex>
// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]]
// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]]
// CHECK: %[[EXTENT_LHS_2:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_2]] : !shape.shape -> tensor<2xindex>
// CHECK: %[[EXTENT_RHS_2:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_2]] : !shape.shape -> tensor<2xindex>
// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C3:.*]] = constant 3 : index
// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]]
// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]]
// CHECK: %[[EXTENT_LHS_3:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_3]] : !shape.shape -> tensor<3xindex>
// CHECK: %[[EXTENT_RHS_3:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_3]] : !shape.shape -> tensor<3xindex>
// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C4:.*]] = constant 4 : index
// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]]
// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]]
// CHECK: %[[EXTENT_LHS_4:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_4]] : !shape.shape -> tensor<4xindex>
// CHECK: %[[EXTENT_RHS_4:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_4]] : !shape.shape -> tensor<4xindex>
// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C5:.*]] = constant 5 : index
// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index
// Handle rank 5 specialization
// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]]
// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]]
// CHECK: %[[EXTENT_LHS_5:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_5]] : !shape.shape -> tensor<5xindex>
// CHECK: %[[EXTENT_RHS_5:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_5]] : !shape.shape -> tensor<5xindex>
// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %[[C6:.*]] = constant 6 : index
// CHECK: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index
// Handle rank 6 specialization
// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]]
// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]]
// CHECK: %[[EXTENT_LHS_6:.*]] = shape.to_extent_tensor %[[BROADCASTED_LHS_6]] : !shape.shape -> tensor<6xindex>
// CHECK: %[[EXTENT_RHS_6:.*]] = shape.to_extent_tensor %[[BROADCASTED_RHS_6]] : !shape.shape -> tensor<6xindex>
// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[EXTENT_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[EXTENT_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32>
// CHECK: } else {
// CHECK: %false = constant false
// CHECK: assert %false
// CHECK: scf.yield %[[LHS]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
// CHECK: }
// CHECK: return %[[VAL_71:.*]] : tensor<*xf32>
// CHECK: }