[MLIR][HLO] Add equal shapes case to rank specialization

Also restructure the lowering implementation to facilitate the addition or
removal of special cases.

PiperOrigin-RevId: 374626365
A. Unique TensorFlower 2021-05-19 05:37:49 -07:00 committed by TensorFlow MLIR Team
parent 71394fb301
commit c62fd89663
2 changed files with 296 additions and 213 deletions
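
At a high level, the restructured lowering emits one guarded fast path before the generic rank cases. A minimal sketch of the resulting IR for a two-operand cluster (hypothetical, simplified names; the authoritative output is pinned down by the CHECK-SCF lines in the test file below):

%shape0 = shape.shape_of %arg0
%shape1 = shape.shape_of %arg1
%eq = shape.shape_eq %shape0, %shape1
%res = scf.if %eq -> (tensor<*xf32>) {
  // Equal shapes: flatten all operands to rank 1, apply the cluster body,
  // and cast the rank-1 result back to an unranked tensor.
  ...
} else {
  // Generic cases: reduce the shapes with chlo.minimum_broadcast_shapes and
  // specialize for maximum reduced ranks 1, ..., 8.
  ...
}
// A single final mhlo.dynamic_reshape restores the broadcasted result shape.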


@@ -182,10 +182,10 @@ bool IsScalarTensorType(Type ty) {
 }

 Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
-  auto unranked_ty = ty.dyn_cast<UnrankedTensorType>();
-  if (!unranked_ty) return ty;
+  auto tensor_ty = ty.dyn_cast<TensorType>();
+  if (!tensor_ty) return ty;
   SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
-  return RankedTensorType::get(shape, unranked_ty.getElementType());
+  return RankedTensorType::get(shape, tensor_ty.getElementType());
 }

 Type DeriveUnrankedTensorTypes(Type ty) {
@@ -208,7 +208,7 @@ Optional<Value> FindUniqueNonScalar(ValueRange values) {
 SmallVector<Value, 8> MaterializeRankedOperations(
     OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
-    chlo::RankSpecializationClusterOp &op, int64_t target_rank) {
+    chlo::RankSpecializationClusterOp op, int64_t target_rank) {
   // Create ranked operations.
   for (Operation &nested_op : op.getBody()->without_terminator()) {
     auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
@@ -256,59 +256,101 @@ SmallVector<Value, 8> MaterializeFinalReshape(
   }));
 }

-struct LowerSingleNonScalarOperandPattern
-    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
-  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
-                                PatternRewriter &rewriter) const override {
-    // Only apply this pattern if we can statically know that all operands have
-    // the same shape or are scalars, i.e. all but one operands are scalars.
-    Optional<Value> non_scalar_operand = FindUniqueNonScalar(op.operands());
-    if (!non_scalar_operand) return failure();
-
-    // Flatten the non-scalar operand.
-    Location loc = op.getLoc();
-    Value flat_shape = rewriter.create<tensor::FromElementsOp>(
-        loc,
-        rewriter
-            .create<shape::NumElementsOp>(
-                loc, rewriter.getIndexType(),
-                rewriter.create<shape::ShapeOfOp>(loc, *non_scalar_operand))
-            .result());
-    Value flat_non_scalar_operand = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, DeriveRankedTensorTypes(non_scalar_operand->getType(), /*rank=*/1),
-        *non_scalar_operand, flat_shape);
-
-    // Materialize ranked variants for the element-wise operations.
-    BlockAndValueMapping bvm;
-    for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-      Value operand = std::get<1>(it);
-      bvm.map(std::get<0>(it), operand == *non_scalar_operand
-                                   ? flat_non_scalar_operand
-                                   : operand);
-    }
-    SmallVector<Value, 8> unshaped_results =
-        MaterializeRankedOperations(rewriter, loc, bvm, op, /*target_rank=*/1);
-
-    // Restore the results' expected shape.
-    SmallVector<Value, 8> results =
-        MaterializeFinalReshape(rewriter, loc, op, unshaped_results);
-    rewriter.replaceOp(op, results);
-    return success();
-  }
-};
-
-Value MaterializeRankSpecialization(OpBuilder &b, Location loc,
-                                    chlo::RankSpecializationClusterOp op,
-                                    const SmallVector<Value, 8> &shapes,
-                                    int64_t target_rank) {
+SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    Value non_scalar_operand) {
+  // Flatten the non-scalar operand.
+  Value flat_shape = b.create<tensor::FromElementsOp>(
+      loc, b.create<shape::NumElementsOp>(
+                loc, b.getIndexType(),
+                b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
+               .result());
+  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
+      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
+      non_scalar_operand, flat_shape);
+
+  // Materialize ranked variants for the element-wise operations.
+  BlockAndValueMapping bvm;
+  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
+    Value operand = std::get<1>(it);
+    bvm.map(std::get<0>(it),
+            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+  }
+  SmallVector<Value, 8> unshaped_results =
+      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+
+  // Restore the results' expected shape.
+  return MaterializeFinalReshape(b, loc, op, unshaped_results);
+}
+
+Value MaterializeEqualShapesRankSpecializationCase(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes,
+    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
+  assert(shapes.size() >= 2 &&
+         "This strategy should only be materialized if there are at least two "
+         "shapes involved.");
+
+  // Materialize all shapes equal predicate.
+  Value all_shapes_eq;
+  for (Value s : llvm::drop_begin(shapes)) {
+    auto literal = b.create<shape::ShapeEqOp>(loc, shapes.front(), s);
+    all_shapes_eq =
+        all_shapes_eq
+            ? b.create<mlir::AndOp>(loc, all_shapes_eq, literal).result()
+            : literal;
+  }
+
+  auto if_op = b.create<scf::IfOp>(
+      loc, op->getResultTypes(), all_shapes_eq,
+      [&](OpBuilder &b, Location loc) {
+        // Flatten operands.
+        Value shape = shapes.front();
+        for (Value s : llvm::drop_begin(shapes)) {
+          shape = b.create<shape::AnyOp>(loc, shape.getType(),
+                                         ValueRange{shape, s});
+        }
+        Value flat_shape = b.create<tensor::FromElementsOp>(
+            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
+                     .result());
+        SmallVector<Value, 8> flat_operands =
+            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              return b
+                  .create<mhlo::DynamicReshapeOp>(
+                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
+                      flat_shape)
+                  .result();
+            }));
+
+        // Materialize ranked variants for the element-wise operations.
+        // TODO(frgossen): Materialize non-broadcasting equivalents instead.
+        BlockAndValueMapping bvm;
+        for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
+          bvm.map(std::get<0>(it), std::get<1>(it));
+        Value unshaped_result =
+            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
+                .front();
+
+        // Return as unranked tensor for compatibility with the other cases.
+        b.create<scf::YieldOp>(
+            loc, b.create<tensor::CastOp>(
+                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
+                     unshaped_result)
+                     .dest());
+      },
+      else_builder_fn);
+
+  return if_op.results().front();
+}
+
+Value MaterializeTargetRankSpecializationCase(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes, int64_t target_rank) {
   // Reshape operands to match the target rank.
-  MLIRContext *ctx = op->getContext();
   llvm::SmallVector<int64_t, 8> ranked_ty_dynamic_dims(
       target_rank, RankedTensorType::kDynamicSize);
   RankedTensorType extent_tensor_ty =
-      shape::getExtentTensorType(ctx, target_rank);
+      shape::getExtentTensorType(b.getContext(), target_rank);
   Value all_ones_shape = b.create<shape::ConstShapeOp>(
       loc, extent_tensor_ty,
       mlir::DenseIntElementsAttr::get(extent_tensor_ty,
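
For two shapes, the IR materialized by MaterializeEqualShapesRankSpecializationCase looks roughly as follows (a sketch with invented value names, mirroring the CHECK-SCF lines in the test file below):

%eq = shape.shape_eq %shape_a, %shape_b
%res = scf.if %eq -> (tensor<*xf32>) {
  %s = shape.any %shape_a, %shape_b
  %n = shape.num_elements %s
  %flat_shape = tensor.from_elements %n
  %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
  %flat_b = "mhlo.dynamic_reshape"(%b, %flat_shape)
  // ... rank-1 variants of the cluster's element-wise operations ...
  %cast = tensor.cast %flat_res : tensor<?xf32> to tensor<*xf32>
  scf.yield %cast : tensor<*xf32>
} else {
  // Populated by else_builder_fn, i.e. the generic rank cases.
}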
@@ -319,7 +361,8 @@ Value MaterializeRankSpecialization(OpBuilder &b, Location loc,
     std::tie(operand, shape) = it;
     Value ranked_shape = b.create<tensor::CastOp>(
         loc, extent_tensor_ty,
-        b.create<shape::BroadcastOp>(loc, shape::getExtentTensorType(ctx),
+        b.create<shape::BroadcastOp>(loc,
+                                     shape::getExtentTensorType(b.getContext()),
                                      shape, all_ones_shape,
                                      /*error=*/nullptr));
   Type element_ty = operand.getType().dyn_cast<TensorType>().getElementType();
@@ -341,11 +384,10 @@ Value MaterializeRankSpecialization(OpBuilder &b, Location loc,
       unshaped_result);
 }

-Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc,
-                                        chlo::RankSpecializationClusterOp op,
-                                        const SmallVector<Value, 8> &shapes,
-                                        Value max_rank, int64_t min_target_rank,
-                                        int64_t max_target_rank) {
+Value RecusivelyMaterializeTargetRankSpecializationCases(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes, Value max_rank,
+    int64_t min_target_rank, int64_t max_target_rank) {
   Value min_target_rank_predicate =
       b.create<CmpIOp>(loc, CmpIPredicate::eq, max_rank,
                        b.create<ConstantIndexOp>(loc, min_target_rank));
@@ -357,7 +399,8 @@ Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc,
                    "Input for dynamic binary or n-ary op lowering was of "
                    "a rank greater than " +
                        std::to_string(max_target_rank));
-    return MaterializeRankSpecialization(b, loc, op, shapes, min_target_rank);
+    return MaterializeTargetRankSpecializationCase(b, loc, op, shapes,
+                                                   min_target_rank);
   }

   // Materialize IR for the smallest considered target rank.
@@ -366,84 +409,107 @@ Value MaterializeAllRankSpecializations(OpBuilder &b, Location loc,
                                    /*withElseRegion=*/true);
   auto then_builder = if_op.getThenBodyBuilder();
   then_builder.create<scf::YieldOp>(
-      loc, MaterializeRankSpecialization(then_builder, loc, op, shapes,
-                                         min_target_rank));
+      loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op,
+                                                   shapes, min_target_rank));

-  // Recur for all remaining target ranks.
+  // Recurse for all remaining target ranks.
   auto else_builder = if_op.getElseBodyBuilder();
   else_builder.create<scf::YieldOp>(
-      loc,
-      MaterializeAllRankSpecializations(else_builder, loc, op, shapes, max_rank,
-                                        min_target_rank + 1, max_target_rank));
+      loc, RecusivelyMaterializeTargetRankSpecializationCases(
+               else_builder, loc, op, shapes, max_rank, min_target_rank + 1,
+               max_target_rank));

   return if_op.results().front();
 }

-struct LowerMultipleNonScalarOperandPattern
+Value MaterializeGenericRankSpecializationCases(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    const SmallVector<Value, 8> &shapes) {
+  // Get the minimum broadcast shapes of the operands.
+  ValueRange reduced_shapes =
+      b.create<chlo::MinimumBroadcastShapesOp>(
+           loc,
+           SmallVector<Type, 8>(shapes.size(),
+                                shape::getExtentTensorType(b.getContext())),
+           shapes)
+          .results();
+  // TODO(frgossen): Avoid this reshape if it is redundant in all cases.
+  SmallVector<Value, 8> reshaped_args;
+  for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
+    Value arg = std::get<0>(it);
+    Value reduced_shape = std::get<1>(it);
+    reshaped_args.push_back(b.create<mhlo::DynamicReshapeOp>(
+        loc, arg.getType(), arg, reduced_shape));
+  }
+
+  // Find the maximum rank among the reduced operand shapes.
+  Value max_rank;
+  for (Value shape : reduced_shapes) {
+    Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
+    if (!max_rank) {
+      max_rank = rank;
+    } else {
+      max_rank = b.create<mlir::SelectOp>(
+          loc, b.create<CmpIOp>(loc, CmpIPredicate::sgt, max_rank, rank),
+          max_rank, rank);
+    }
+  }
+
+  // Materialize rank specialization for ranks 1, ..., 8.
+  // TODO(frgossen): For clusters w/o a select operation, consider only ranks
+  // 1, ..., 5.
+  const int64_t kMinTargetRank = 1;
+  const int64_t kMaxTargetRank = 8;
+  return RecusivelyMaterializeTargetRankSpecializationCases(
+      b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
+}
+
+Value MaterializeDefaultRankSpecialization(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeEqualShapesRankSpecializationCase(
+      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
+struct LowerRankSpecializationClusterPattern
     : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
   using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
-    // We have a specialized pattern for the case in which all but one operands
-    // are scalars.
-    if (FindUniqueNonScalar(op.operands())) return failure();
+    // TODO(frgossen): If there is a single operand, we can flatten it
+    // completely and apply a non-broadcasting operation.
+
+    // If there is only one unranked operand and all others are known scalars,
+    // we can flatten the operands to rank 1.
+    Location loc = op.getLoc();
+    if (Optional<Value> non_scalar_operand =
+            FindUniqueNonScalar(op.operands())) {
+      rewriter.replaceOp(op,
+                         MaterializeRankSpecializationForSingleNonScalarOperand(
+                             rewriter, loc, op, *non_scalar_operand));
+      return success();
+    }

     // Restoring the result shape currently relies on all operands being used
     // for a single result. The result shape is then the broadcasted shape of
     // all operands.
     if (op.getNumResults() != 1) return failure();

-    // Get the minimum broadcast shapes of the operands.
-    Location loc = op.getLoc();
-    SmallVector<Value, 8> shapes =
-        llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
-          return rewriter.create<shape::ShapeOfOp>(loc, v).result();
-        }));
-    ValueRange reduced_shapes =
-        rewriter
-            .create<chlo::MinimumBroadcastShapesOp>(
-                loc,
-                SmallVector<Type, 8>(shapes.size(),
-                                     shape::getExtentTensorType(getContext())),
-                shapes)
-            .results();
-    // TODO(frgossen): Avoid this reshape if it is redundant in all cases.
-    SmallVector<Value, 8> reshaped_args;
-    for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
-      Value arg = std::get<0>(it);
-      Value reduced_shape = std::get<1>(it);
-      reshaped_args.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, arg.getType(), arg, reduced_shape));
-    }
-
-    // Find the maximum rank among the reduced operand shapes.
-    Value max_rank;
-    for (Value shape : reduced_shapes) {
-      Value rank =
-          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
-      if (!max_rank) {
-        max_rank = rank;
-      } else {
-        max_rank = rewriter.create<mlir::SelectOp>(
-            loc,
-            rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, max_rank, rank),
-            max_rank, rank);
-      }
-    }
-
-    // Materialize rank specialization for ranks 1, ..., 8.
-    // TODO(frgossen): For clusters w/o a select operation, consider only ranks
-    // 1, ..., 5.
-    const int64_t kMinTargetRank = 1;
-    const int64_t kMaxTargetRank = 8;
-    Value unshaped_result = MaterializeAllRankSpecializations(
-        rewriter, loc, op, reduced_shapes, max_rank, kMinTargetRank,
-        kMaxTargetRank);
-
-    // Materialize final reshape once and for all rank specialization cases.
-    rewriter.replaceOp(
-        op, MaterializeFinalReshape(rewriter, loc, op, unshaped_result));
+    // For all other cases, reshape the operands to match in rank, apply the
+    // operation, and restore the expected shape.
+    rewriter.replaceOp(op,
+                       MaterializeDefaultRankSpecialization(rewriter, loc, op));

     return success();
   }
 };
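
RecusivelyMaterializeTargetRankSpecializationCases unrolls into a chain of nested scf.if ops keyed on the maximum reduced rank; schematically (a sketch, not verbatim pass output):

%is_rank_1 = cmpi eq, %max_rank, %c1
%res = scf.if %is_rank_1 -> (tensor<*xf32>) {
  // Rank-1 specialization.
} else {
  %is_rank_2 = cmpi eq, %max_rank, %c2
  %inner = scf.if %is_rank_2 -> (tensor<*xf32>) {
    // Rank-2 specialization.
  } else {
    // ... and so on up to the maximum target rank, where an assert guards
    // the unconditional final case.
  }
  scf.yield %inner : tensor<*xf32>
}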
@@ -475,8 +541,7 @@ void PopulateRankSpecializationClusterPatterns(
 void PopulateRankSpecializationToSCFPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<LowerSingleNonScalarOperandPattern,
-                   LowerMultipleNonScalarOperandPattern>(context);
+  patterns->insert<LowerRankSpecializationClusterPattern>(context);
 }

 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
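
For orientation, a sketch of how the populated patterns are typically driven (a hypothetical pass body using standard MLIR boilerplate of this era; the actual pass definition is outside this hunk):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

struct RankSpecializationToSCFPassSketch
    : public mlir::PassWrapper<RankSpecializationToSCFPassSketch,
                               mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::OwningRewritePatternList patterns(ctx);
    // Registers LowerRankSpecializationClusterPattern (see the hunk above).
    PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
    // Apply the patterns until a fixed point is reached.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getFunction(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};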


@@ -40,162 +40,180 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
 // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
-// Find maximum reduced rank.
-// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
-// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2
-// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
-// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
-// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
-// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
-// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %15, %[[REDUCED_RANK1]]
-// Case 1:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
-// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
-// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
-// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[EQ20:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
+// CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S201]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[FLAT_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[FLAT_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 2:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
-// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#2
+// CHECK-SCF-DAG: %[[REDUCED_RANK2:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[R2_GT_R0:.*]] = cmpi sgt, %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20:.*]] = select %[[R2_GT_R0]], %[[REDUCED_RANK2]], %[[REDUCED_RANK0]]
+// CHECK-SCF-DAG: %[[R20_GT_R1:.*]] = cmpi sgt, %[[R20]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R20_GT_R1]], %[[R20]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_1:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES_1:.*]] = scf.if %[[MAX_RED_RANK_EQ_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?xf32>, tensor<?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 3:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
-// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
+// Generic case 2:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_2:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C2]]
+// CHECK-SCF: %[[UNSHAPED_RES_2:.*]] = scf.if %[[MAX_RED_RANK_EQ_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_2]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_2]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 4:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
-// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
+// Generic case 3:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_3:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C3]]
+// CHECK-SCF: %[[UNSHAPED_RES_3:.*]] = scf.if %[[MAX_RED_RANK_EQ_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_3]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_3]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 5:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
-// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
+// Generic case 4:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_4:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C4]]
+// CHECK-SCF: %[[UNSHAPED_RES_4:.*]] = scf.if %[[MAX_RED_RANK_EQ_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_4]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_4]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 6:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
-// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
+// Generic case 5:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_5:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C5]]
+// CHECK-SCF: %[[UNSHAPED_RES_5:.*]] = scf.if %[[MAX_RED_RANK_EQ_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_5]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_5]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 7:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
-// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
+// Generic case 6:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_6:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C6]]
+// CHECK-SCF: %[[UNSHAPED_RES_6:.*]] = scf.if %[[MAX_RED_RANK_EQ_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_6]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_6]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
 // CHECK-SCF: else
-// Case 8:
-// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
-// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
-// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// Generic case 7:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_7:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C7]]
+// CHECK-SCF: %[[UNSHAPED_RES_7:.*]] = scf.if %[[MAX_RED_RANK_EQ_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_7]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_7]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
 // CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
 // CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
-// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>)
 // CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
 // CHECK-SCF: scf.yield %[[INNER_RES_]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
-// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: else
+// Generic case 8:
+// CHECK-SCF: %[[MAX_RED_RANK_EQ_8:.*]] = cmpi eq, %[[MAX_RED_RANK]], %[[C8]]
+// CHECK-SCF: assert %[[MAX_RED_RANK_EQ_8]], "Input for dynamic binary or n-ary op lowering was of a rank greater than 8"
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#2, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_8]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG2_:.*]] = tensor.cast %[[EXT_SHAPE_ARG2]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG2:.*]] = "mhlo.dynamic_reshape"(%[[ARG2]], %[[EXT_SHAPE_ARG2_]])
+// CHECK-SCF-DAG: %[[TMP:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[REDUCED_ARG1]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[REDUCED_ARG2]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_7]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_6]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_5]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_4]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_3]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_2]]
+// CHECK-SCF: scf.yield %[[UNSHAPED_RES_1]]
 // Reshape the result.
 // CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
 // CHECK-SCF-DAG: %[[SHAPE_ARG2:.*]] = shape.shape_of %[[ARG2]]
 // CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
-// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_1]], %[[RES_SHAPE]])
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_EQ_SHAPES]], %[[RES_SHAPE]])
 // CHECK-SCF: return %[[RES]]

 // -----
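
For reference, CHECK-SCF prefixes like the ones above are exercised by a FileCheck RUN line roughly of this form (the actual RUN line is not part of this diff; the tool and flag names here are assumptions derived from the pass constructors in this commit):

// RUN: mlir-hlo-opt %s --split-input-file --rank-specialization-cluster --rank-specialization-to-scf | FileCheck %s --check-prefix=CHECK-SCF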