[MLIR][HLO] Exploit scalar properties in rank specialization lowering

Take advantage of the fact that scalars are already ranked and that they are
neutral elements with respect to broadcasting. Do not reshape scalars, do not
consider them for broadcasting, and materialize ranked operations on scalars
accordingly.

PiperOrigin-RevId: 375968371
A. Unique TensorFlower 2021-05-26 09:58:25 -07:00 committed by TensorFlow MLIR Team
parent 2f8f3d692c
commit 4ebcebf31c
2 changed files with 199 additions and 97 deletions
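
As an illustration of the idea (a hypothetical example, not part of this commit): consider a cluster formed around a CHLO broadcasting op in which one operand is a statically known scalar.

  // Hypothetical input, assuming the rank specialization pass clusters this op.
  func @add_scalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
    %0 = chlo.broadcast_add %arg0, %arg1
        : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
    return %0 : tensor<*xf32>
  }

Because %arg1 is already ranked (rank 0) and neutral with respect to broadcasting, the lowering changes below can skip it throughout: its shape needs no runtime scalar test, it is excluded from the shape-equality and minimum-broadcast-shape computations, and it is passed to the materialized ranked operation unchanged instead of being reshaped to the target rank.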

@@ -181,6 +181,10 @@ bool IsScalarTensorType(Type ty) {
   return ranked_ty && ranked_ty.getRank() == 0;
 }
 
+bool IsScalarShapeType(Type ty) {
+  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
+}
+
 Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
   auto tensor_ty = ty.dyn_cast<TensorType>();
   if (!tensor_ty) return ty;
@@ -208,11 +212,16 @@ Optional<Value> FindUniqueNonScalar(ValueRange values) {
 
 SmallVector<Value, 8> MaterializeRankedOperations(
     OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
-    chlo::RankSpecializationClusterOp op, int64_t target_rank) {
+    chlo::RankSpecializationClusterOp op) {
   // Create ranked operations.
   for (Operation &nested_op : op.getBody()->without_terminator()) {
     auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
        nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
+    int64_t target_rank = 0;
+    for (Value v : mapped_operands) {
+      target_rank =
+          std::max(target_rank, v.getType().cast<RankedTensorType>().getRank());
+    }
     auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
         nested_op.getResultTypes(),
         [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
@@ -265,12 +274,15 @@ Value MaterializeScalarRankSpecializationCase(
   Value all_others_are_scalar;
   for (auto it : llvm::enumerate(shapes)) {
     if (it.index() == non_scalar_idx) continue;
+    // For statically known scalars, there is no need to test.
+    if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
     auto literal =
         b.create<CmpIOp>(loc, CmpIPredicate::eq,
                          b.create<shape::NumElementsOp>(loc, it.value()), one);
     all_others_are_scalar =
         all_others_are_scalar
-            ? b.create<AndOp>(loc, all_others_are_scalar, literal).getResult()
+            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
+                  .getResult()
             : literal.result();
   }
 
@@ -303,8 +315,7 @@ Value MaterializeScalarRankSpecializationCase(
         for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
           bvm.map(std::get<0>(it), std::get<1>(it));
         Value unshaped_result =
-            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
-                .front();
+            MaterializeRankedOperations(b, loc, bvm, op).front();
 
         // Return as unranked tensor for compatibility with the other cases.
         b.create<scf::YieldOp>(
@@ -322,26 +333,29 @@ Value MaterializeEqualShapesRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes,
     function_ref<void(OpBuilder &, Location)> else_builder_fn) {
-  assert(shapes.size() >= 2 &&
-         "This strategy should only be materialized if there are at least two "
-         "shapes involved.");
-
   // Materialize all shapes equal predicate.
-  Value all_shapes_eq;
-  for (Value s : llvm::drop_begin(shapes)) {
-    auto literal = b.create<shape::ShapeEqOp>(loc, shapes.front(), s);
-    all_shapes_eq =
-        all_shapes_eq
-            ? b.create<mlir::AndOp>(loc, all_shapes_eq, literal).result()
+  Value all_shapes_eq_or_scalar;
+  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
+  assert(
+      non_scalar_shapes.size() >= 2 &&
+      "Equal shapes strategy requires at least two non-scalar operand shapes.");
+  for (Value s : llvm::drop_begin(non_scalar_shapes)) {
+    auto literal =
+        b.create<shape::ShapeEqOp>(loc, non_scalar_shapes.front(), s);
+    all_shapes_eq_or_scalar =
+        all_shapes_eq_or_scalar
+            ? b.create<mlir::AndOp>(loc, all_shapes_eq_or_scalar, literal)
+                  .result()
             : literal;
   }
 
   auto if_op = b.create<scf::IfOp>(
-      loc, op->getResultTypes(), all_shapes_eq,
+      loc, op->getResultTypes(), all_shapes_eq_or_scalar,
       [&](OpBuilder &b, Location loc) {
-        // Flatten operands.
-        Value shape = shapes.front();
-        for (Value s : llvm::drop_begin(shapes)) {
+        // Flatten non-scalar operands.
+        Value shape = non_scalar_shapes.front();
+        for (Value s : llvm::drop_begin(non_scalar_shapes)) {
          shape = b.create<shape::AnyOp>(loc, shape.getType(),
                                         ValueRange{shape, s});
         }
@@ -350,6 +364,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
                 .result());
         auto flat_operands =
             llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+              if (IsScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
@@ -358,13 +373,11 @@ Value MaterializeEqualShapesRankSpecializationCase(
             }));
 
         // Materialize ranked variants for the element-wise operations.
-        // TODO(frgossen): Materializae non-broadcasting equivalents instead.
         BlockAndValueMapping bvm;
         for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
           bvm.map(std::get<0>(it), std::get<1>(it));
         Value unshaped_result =
-            MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
-                .front();
+            MaterializeRankedOperations(b, loc, bvm, op).front();
 
         // Return as unranked tensor for compatibility with the other cases.
         b.create<scf::YieldOp>(
@@ -382,8 +395,6 @@ Value MaterializeTargetRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes, int64_t target_rank) {
   // Reshape operands to match the target rank.
-  llvm::SmallVector<int64_t, 8> ranked_ty_dynamic_dims(
-      target_rank, RankedTensorType::kDynamicSize);
   RankedTensorType extent_tensor_ty =
       shape::getExtentTensorType(b.getContext(), target_rank);
   Value all_ones_shape = b.create<shape::ConstShapeOp>(
@@ -394,16 +405,19 @@ Value MaterializeTargetRankSpecializationCase(
   for (auto it : llvm::zip(op.operands(), shapes)) {
     Value operand, shape;
     std::tie(operand, shape) = it;
+    if (operand.getType().isa<RankedTensorType>()) {
+      ranked_operands.push_back(operand);
+      continue;
+    }
     Value ranked_shape = b.create<tensor::CastOp>(
         loc, extent_tensor_ty,
         b.create<shape::BroadcastOp>(loc,
                                      shape::getExtentTensorType(b.getContext()),
                                      shape, all_ones_shape,
                                      /*error=*/nullptr));
-    Type element_ty = operand.getType().dyn_cast<TensorType>().getElementType();
-    auto ranked_ty = RankedTensorType::get(ranked_ty_dynamic_dims, element_ty);
     ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
-        loc, ranked_ty, operand, ranked_shape));
+        loc, DeriveRankedTensorTypes(operand.getType(), target_rank), operand,
+        ranked_shape));
   }
 
   // Materialize ranked versions of the element-wise operations.
@@ -412,8 +426,7 @@ Value MaterializeTargetRankSpecializationCase(
     bvm.map(std::get<0>(it), std::get<1>(it));
 
   // Return as unranked for compatibility with other target ranks.
-  auto unshaped_result =
-      MaterializeRankedOperations(b, loc, bvm, op, target_rank).front();
+  auto unshaped_result = MaterializeRankedOperations(b, loc, bvm, op).front();
   return b.create<tensor::CastOp>(
       loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
       unshaped_result);
@@ -460,25 +473,17 @@ Value MaterializeGenericRankSpecializationCases(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     const SmallVector<Value, 8> &shapes) {
   // Get the minimum broadcast shapes of the operands.
-  ValueRange reduced_shapes =
-      b.create<chlo::MinimumBroadcastShapesOp>(
-          loc,
-          SmallVector<Type, 8>(shapes.size(),
-                               shape::getExtentTensorType(b.getContext())),
-          shapes)
-          .results();
-  // TODO(frgossen): Avoid this reshape if it is redundant in all cases.
-  SmallVector<Value, 8> reshaped_args;
-  for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
-    Value arg = std::get<0>(it);
-    Value reduced_shape = std::get<1>(it);
-    reshaped_args.push_back(b.create<mhlo::DynamicReshapeOp>(
-        loc, arg.getType(), arg, reduced_shape));
-  }
+  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
+      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
+  auto min_bcast_shapes_op = b.create<chlo::MinimumBroadcastShapesOp>(
+      loc,
+      SmallVector<Type, 8>(non_scalar_shapes.size(),
+                           shape::getExtentTensorType(b.getContext())),
+      non_scalar_shapes);
 
   // Find the maximum rank among the reduced operand shapes.
   Value max_rank;
-  for (Value shape : reduced_shapes) {
+  for (Value shape : min_bcast_shapes_op.results()) {
     Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
     if (!max_rank) {
       max_rank = rank;
@@ -489,6 +494,17 @@ Value MaterializeGenericRankSpecializationCases(
     }
   }
 
+  // Collect reduced shapes.
+  SmallVector<Value, 8> reduced_shapes;
+  auto it = min_bcast_shapes_op.result_begin();
+  for (Value s : shapes) {
+    if (IsScalarShapeType(s.getType())) {
+      reduced_shapes.push_back(s);
+    } else {
+      reduced_shapes.push_back(*it++);
+    }
+  }
+
   // Materialize rank specialization for ranks 1, ..., 8.
   // TODO(frgossen): For clusters w/o a select operation, consider only ranks
   // 1, ..., 5.
@@ -508,34 +524,6 @@ Value MaterializeDefaultRankSpecializationCases(
       });
 }
 
-Value MaterializeRankSpecializationForExactlyTwoOperands(
-    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
-  assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
-         "The rank specialization strategy for clusters with exactly two "
-         "operands supports only one result.");
-  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
-    return b.create<shape::ShapeOfOp>(loc, v).result();
-  }));
-
-  // Materialize all the different cases.
-  Value unshaped_result = MaterializeScalarRankSpecializationCase(
-      b, loc, op, shapes, /*non_scalar_idx=*/1,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(
-            loc, MaterializeScalarRankSpecializationCase(
-                     b, loc, op, shapes, /*non_scalar_idx=*/0,
-                     [&](OpBuilder &b, Location loc) {
-                       b.create<scf::YieldOp>(
-                           loc, MaterializeDefaultRankSpecializationCases(
-                                    b, loc, op, shapes));
-                     }));
-      });
-
-  // Materialize final reshape once and for all rank specialization cases.
-  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
-}
-
 SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
     Value non_scalar_operand) {
@@ -552,17 +540,48 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
   // Materialize ranked variants for the element-wise operations.
   BlockAndValueMapping bvm;
   for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
-    Value operand = std::get<1>(it);
-    bvm.map(std::get<0>(it),
+    Value operand;
+    Value bb_arg;
+    std::tie(bb_arg, operand) = it;
+    bvm.map(bb_arg,
             operand == non_scalar_operand ? flat_non_scalar_operand : operand);
   }
   SmallVector<Value, 8> unshaped_results =
-      MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
+      MaterializeRankedOperations(b, loc, bvm, op);
 
   // Restore the results' expected shape.
   return MaterializeFinalReshape(b, loc, op, unshaped_results);
 }
 
+Value MaterializeRankSpecializationForTwoNonScalarOperands(
+    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
+    ValueRange non_scalar_operands) {
+  assert(non_scalar_operands.size() == 2);
+
+  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
+    return b.create<shape::ShapeOfOp>(loc, v).result();
+  }));
+  auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
+  auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
+
+  // Materialize all the different cases.
+  Value unshaped_result = MaterializeScalarRankSpecializationCase(
+      b, loc, op, shapes, non_scalar_rhs.getIndex(),
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(
+            loc, MaterializeScalarRankSpecializationCase(
+                     b, loc, op, shapes, non_scalar_lhs.getIndex(),
+                     [&](OpBuilder &b, Location loc) {
+                       b.create<scf::YieldOp>(
+                           loc, MaterializeDefaultRankSpecializationCases(
+                                    b, loc, op, shapes));
+                     }));
+      });
+
+  // Materialize final reshape once and for all rank specialization cases.
+  return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
+}
+
 // Materialize rank generic rank specialization.
 Value MaterializeDefaultRankSpecialization(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
@@ -584,31 +603,36 @@ struct LowerRankSpecializationClusterPattern
   LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-
     // Restoring the result shape currently relies on all operands being used
     // for a single result. The result shape is then the broadcasted shape of
     // all operands.
     if (op.getNumResults() != 1) return failure();
 
-    // TODO(frgossen): If there is a single operand, we can flatten it
-    // completely and apply a non-broadcasting operation.
-
-    // If there is only one unranked operand and all others are known scalars,
-    // we can flatten the operands to rank 1.
-    if (Optional<Value> non_scalar_operand =
-            FindUniqueNonScalar(op.operands())) {
+    // If there is only a single non-scalar operand, we can flatten that operand
+    // completely.
+    Location loc = op.getLoc();
+    auto non_scalar_operands =
+        llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
+          return !IsScalarTensorType(v.getType());
+        }));
+    if (non_scalar_operands.size() == 1) {
      rewriter.replaceOp(op,
                         MaterializeRankSpecializationForSingleNonScalarOperand(
-                             rewriter, loc, op, *non_scalar_operand));
+                             rewriter, loc, op, non_scalar_operands.front()));
       return success();
     }
 
-    // If there are only two operands, we can consider extra cases in which
-    // either operand is scalar.
-    if (op->getNumOperands() == 2) {
-      rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
-                                 rewriter, loc, op));
+    // If there are exactly two unranked operands and all others are known to be
+    // scalars, we can consider two extra cases: If either of the unranked
+    // operands turns out to be a scalar at runtime, we can, again, apply the
+    // trick for a single non-scalar operand.
+    if (non_scalar_operands.size() == 2 &&
+        llvm::all_of(non_scalar_operands, [](Value v) {
+          return v.getType().isa<UnrankedTensorType>();
+        })) {
+      rewriter.replaceOp(op,
+                         MaterializeRankSpecializationForTwoNonScalarOperands(
+                             rewriter, loc, op, non_scalar_operands));
       return success();
     }

@@ -267,7 +267,7 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
 
 // -----
 
-// Ternary operation.
+// Operation with mixed ranked and unranked operands.
 // CHECK-LABEL: @select_mixed
 // CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
 func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
@@ -284,6 +284,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
 
 // CHECK-SCF-LABEL: @select_mixed
 // CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF: return
 
 // -----
@@ -417,6 +418,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
 
 // -----
 
+// Scalar cluster operand.
 // CHECK-LABEL: @xlogy
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
 func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
@@ -441,11 +443,87 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
   return %5 : tensor<*xf32>
 }
 
-// CHECK-SCF-LABEL: @xlogy
-// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %{{.*}}, %{{.*}} {comparison_direction = "EQ"} : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
-// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
+// CHECK-SCF: @xlogy
+// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
+// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
+// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.00{{.*}}>
+// Lhs scalar case:
+// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
+// CHECK-SCF: %[[UNSHAPED_RES:.*]] = scf.if %[[LHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_NON_SCALAR]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[TMP0]] : (tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<i1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Rhs scalar case:
+// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[RHS_SCALAR]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_NON_SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[SCALAR]]) : (tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[TMP0]] : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Equal shapes case:
+// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[SHAPES_EQ]]
+// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
+// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_ARG1]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// Find maximum reduced rank.
+// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
+// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
+// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
+// Generic case 1:
+// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
+// CHECK-SCF: %{{.*}} = scf.if %[[MAX_RED_RANK_LE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
+// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
+// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
+// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[REDUCED_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
+// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[REDUCED_ARG1]]) : (tensor<?xf32>)
+// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
+// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
+// CHECK-SCF: scf.yield %[[INNER_RES_]]
+// CHECK-SCF: else
+// ...
+// Reshape the result.
+// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
+// CHECK-SCF: return %[[RES]]
 
 // -----