[MLIR][HLO] Add scalar cases for binary rank specialization

For rank specialization clusters that have only two operands, we can materialize
two extra cases in which either of them is a scalar. This avoids redundant index
computations in these cases.
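
For a two-operand cluster such as `chlo.broadcast_multiply` on unranked tensors, the
emitted control flow then takes roughly the following shape (an illustrative MLIR
sketch only; value names are made up and the remaining cases are elided, see the
updated tests for the exact output):

  func @mul_sketch(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
    %c1 = constant 1 : index
    %shape0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
    %shape1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
    %n0 = shape.num_elements %shape0 : tensor<?xindex> -> index
    %lhs_is_scalar = cmpi eq, %n0, %c1 : index
    %res = scf.if %lhs_is_scalar -> (tensor<*xf32>) {
      // Lhs is scalar: cast it to tensor<f32>, flatten the rhs to rank 1, and
      // apply the operation on the ranked types.
      %n1 = shape.num_elements %shape1 : tensor<?xindex> -> index
      %flat_shape = tensor.from_elements %n1 : tensor<1xindex>
      %flat_rhs = "mhlo.dynamic_reshape"(%arg1, %flat_shape)
          : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
      %scalar_lhs = tensor.cast %arg0 : tensor<*xf32> to tensor<f32>
      %ranked_res = chlo.broadcast_multiply %scalar_lhs, %flat_rhs
          : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
      %unranked_res = tensor.cast %ranked_res : tensor<?xf32> to tensor<*xf32>
      scf.yield %unranked_res : tensor<*xf32>
    } else {
      // Stand-in for the symmetric rhs-scalar case and the existing
      // equal-shapes and per-rank cases.
      scf.yield %arg1 : tensor<*xf32>
    }
    // A single final reshape restores the broadcasted result shape for all cases.
    return %res : tensor<*xf32>
  }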

PiperOrigin-RevId: 375037390
A. Unique TensorFlower authored 2021-05-21 01:34:45 -07:00; committed by TensorFlow MLIR Team
parent a7884196f5
commit 3daf65578a
2 changed files with 340 additions and 39 deletions


@@ -256,31 +256,66 @@ SmallVector<Value, 8> MaterializeFinalReshape(
}));
}
SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
Value MaterializeScalarRankSpecializationCase(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
Value non_scalar_operand) {
// Flatten the non-scalar operand.
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(
loc, b.getIndexType(),
b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
.result());
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
non_scalar_operand, flat_shape);
// Materialize ranked variants for the element-wise operations.
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
Value operand = std::get<1>(it);
bvm.map(std::get<0>(it),
operand == non_scalar_operand ? flat_non_scalar_operand : operand);
const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
function_ref<void(OpBuilder &, Location)> else_builder_fn) {
// Materialize predicate: All operands except one are scalars.
Value one = b.create<ConstantIndexOp>(loc, 1);
Value all_others_are_scalar;
for (auto it : llvm::enumerate(shapes)) {
if (it.index() == non_scalar_idx) continue;
auto literal =
b.create<CmpIOp>(loc, CmpIPredicate::eq,
b.create<shape::NumElementsOp>(loc, it.value()), one);
all_others_are_scalar =
all_others_are_scalar
? b.create<AndOp>(loc, all_others_are_scalar, literal).getResult()
: literal.result();
}
SmallVector<Value, 8> unshaped_results =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
// Restore the results' expected shape.
return MaterializeFinalReshape(b, loc, op, unshaped_results);
auto if_op = b.create<scf::IfOp>(
loc, op->getResultTypes(), all_others_are_scalar,
[&](OpBuilder &b, Location loc) {
// Flatten the non-scalar operand.
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
shapes[non_scalar_idx])
.result());
Value non_scalar_operand = op.operands()[non_scalar_idx];
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
loc,
DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
non_scalar_operand, flat_shape);
// Derive ranked operands.
auto ranked_operands =
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
if (v == non_scalar_operand) return flat_non_scalar_operand;
return b
.create<tensor::CastOp>(
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
.dest();
}));
// Materialize ranked variants for the element-wise operations.
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
bvm.map(std::get<0>(it), std::get<1>(it));
Value unshaped_result =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
.front();
// Return as unranked tensor for compatibility with the other cases.
b.create<scf::YieldOp>(
loc, b.create<tensor::CastOp>(
loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
unshaped_result)
.dest());
},
else_builder_fn);
return if_op.results().front();
}
Value MaterializeEqualShapesRankSpecializationCase(
@@ -313,7 +348,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
.result());
SmallVector<Value, 8> flat_operands =
auto flat_operands =
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b
.create<mhlo::DynamicReshapeOp>(
@@ -355,7 +390,7 @@ Value MaterializeTargetRankSpecializationCase(
loc, extent_tensor_ty,
mlir::DenseIntElementsAttr::get(extent_tensor_ty,
SmallVector<int64_t, 6>(target_rank, 1)));
SmallVector<Value, 2> ranked_operands;
SmallVector<Value, 8> ranked_operands;
for (auto it : llvm::zip(op.operands(), shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
@@ -464,6 +499,72 @@ Value MaterializeGenericRankSpecializationCases(
b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
}
Value MaterializeDefaultRankSpecializationCases(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes) {
return MaterializeEqualShapesRankSpecializationCase(
b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
});
}
Value MaterializeRankSpecializationForExactlyTwoOperands(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
"The rank specialization strategy for clusters with exactly two "
"operands supports only one result.");
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result();
}));
// Materialize all the different cases.
Value unshaped_result = MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, /*non_scalar_idx=*/1,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, /*non_scalar_idx=*/0,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeDefaultRankSpecializationCases(
b, loc, op, shapes));
}));
});
// Materialize final reshape once and for all rank specialization cases.
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
}
SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
Value non_scalar_operand) {
// Flatten the non-scalar operand.
Value flat_shape = b.create<tensor::FromElementsOp>(
loc, b.create<shape::NumElementsOp>(
loc, b.getIndexType(),
b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
.result());
Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
non_scalar_operand, flat_shape);
// Materialize ranked variants for the element-wise operations.
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
Value operand = std::get<1>(it);
bvm.map(std::get<0>(it),
operand == non_scalar_operand ? flat_non_scalar_operand : operand);
}
SmallVector<Value, 8> unshaped_results =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
// Restore the results' expected shape.
return MaterializeFinalReshape(b, loc, op, unshaped_results);
}
// Materialize rank generic rank specialization.
Value MaterializeDefaultRankSpecialization(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
@@ -471,11 +572,8 @@ Value MaterializeDefaultRankSpecialization(
}));
// Materialize all the different cases.
Value unshaped_result = MaterializeEqualShapesRankSpecializationCase(
b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
});
Value unshaped_result =
MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
// Materialize final reshape once and for all rank specialization cases.
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
@@ -487,6 +585,11 @@ struct LowerRankSpecializationClusterPattern
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override {
// Restoring the result shape currently relies on all operands being used
// for a single result. The result shape is then the broadcasted shape of
// all operands.
if (op.getNumResults() != 1) return failure();
// TODO(frgossen): If there is a single operand, we can flatten it
// completely and apply a non-broadcasting operation.
@@ -501,10 +604,13 @@ struct LowerRankSpecializationClusterPattern
return success();
}
// Restoring the result shape currently relies on all operands being used
// for a single result. The result shape is then the broadcasted shape of
// all operands.
if (op.getNumResults() != 1) return failure();
// If there are only two operands, we can consider extra cases in which
// either operand is scalar.
if (op->getNumOperands() == 2) {
rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
rewriter, loc, op));
return success();
}
// For all other cases, reshape the operands to match in rank, apply the
// operation, and restore the expected shape.


@@ -346,12 +346,12 @@ func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
}
// CHECK-SCF-LABEL: @mixed
// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor<?xf32>
// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor<?xf32>)
// CHECK-SCF: chlo.tan %[[TMP4]] : tensor<?xf32>
// CHECK-SCF-DAG: %[[TMP0:.*]] = chlo.tan %{{.*}} : tensor<?xf32>
// CHECK-SCF-DAG: %[[TMP1:.*]] = "mhlo.sqrt"(%{{.*}}) : (tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %{{.*}} : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) : (tensor<?xf32>)
// CHECK-SCF: chlo.tan %[[TMP4]] : tensor<?xf32>
// -----
@@ -446,3 +446,198 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// -----
// CHECK-LABEL: @mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
// CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
// CHECK: return %[[RES]]
%0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-SCF-LABEL: @mul
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
// CHECK-SCF-DAG: %[[C2:.*]] = constant 2
// CHECK-SCF-DAG: %[[C3:.*]] = constant 3
// CHECK-SCF-DAG: %[[C4:.*]] = constant 4
// CHECK-SCF-DAG: %[[C5:.*]] = constant 5
// CHECK-SCF-DAG: %[[C6:.*]] = constant 6
// CHECK-SCF-DAG: %[[C7:.*]] = constant 7
// CHECK-SCF-DAG: %[[C8:.*]] = constant 8
// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_7:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[ONE_SHAPE_8:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1, 1, 1]
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// Lhs scalar case:
// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_LHS_SCALAR:.*]] = scf.if %[[LHS_SCALAR]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG0]]
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[FLAT_NON_SCALAR]] : (tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Rhs scalar case:
// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_RHS_SCALAR:.*]] = scf.if %[[RHS_SCALAR]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = tensor.cast %[[ARG1]]
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[SCALAR]] : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Equal shapes case:
// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Find maximum reduced rank.
// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// Generic case 1:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 2:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 3:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 4:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 5:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 6:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 7:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Generic case 8:
// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_EQ_SHAPES]]
// CHECK-SCF: scf.yield %[[UNSHAPED_RES_RHS_SCALAR]]
// Reshape the result.
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]]