[MLIR][HLO] Exploit scalar properties in rank specialization lowering

Take advantage of the fact that scalars are already ranked and that they are
neutral elements with respect to broadcasting. Do not reshape scalars, do not
consider them when computing broadcast shapes, and materialize ranked
operations on scalars accordingly.

PiperOrigin-RevId: 375968371
A. Unique TensorFlower 2021-05-26 09:58:25 -07:00 committed by TensorFlow MLIR Team
parent 2f8f3d692c
commit 4ebcebf31c
2 changed files with 199 additions and 97 deletions
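For illustration only (not part of this commit): a minimal sketch of the kind of cluster this change targets, using a hypothetical @mul_with_scalar function and made-up SSA names. One operand is unranked, the other is a statically known scalar; with this change the scalar is neither reshaped nor considered when computing broadcast shapes.

// Hypothetical input cluster (sketch; names and exact syntax are assumptions):
func @mul_with_scalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
  %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ( {
  ^bb0(%lhs: tensor<*xf32>, %rhs: tensor<f32>):
    %1 = chlo.broadcast_multiply %lhs, %rhs
        : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}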


@@ -181,6 +181,10 @@ bool IsScalarTensorType(Type ty) {
return ranked_ty && ranked_ty.getRank() == 0;
}
bool IsScalarShapeType(Type ty) {
return ty.cast<RankedTensorType>().getDimSize(0) == 0;
}
Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
auto tensor_ty = ty.dyn_cast<TensorType>();
if (!tensor_ty) return ty;
@@ -208,11 +212,16 @@ Optional<Value> FindUniqueNonScalar(ValueRange values) {
SmallVector<Value, 8> MaterializeRankedOperations(
OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
chlo::RankSpecializationClusterOp op, int64_t target_rank) {
chlo::RankSpecializationClusterOp op) {
// Create ranked operations.
for (Operation &nested_op : op.getBody()->without_terminator()) {
auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
int64_t target_rank = 0;
for (Value v : mapped_operands) {
target_rank =
std::max(target_rank, v.getType().cast<RankedTensorType>().getRank());
}
auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
nested_op.getResultTypes(),
[&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
@@ -265,12 +274,15 @@ Value MaterializeScalarRankSpecializationCase(
Value all_others_are_scalar;
for (auto it : llvm::enumerate(shapes)) {
if (it.index() == non_scalar_idx) continue;
// For statically known scalars, there is no need to test.
if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
auto literal =
b.create<CmpIOp>(loc, CmpIPredicate::eq,
b.create<shape::NumElementsOp>(loc, it.value()), one);
all_others_are_scalar =
all_others_are_scalar
? b.create<AndOp>(loc, all_others_are_scalar, literal).getResult()
? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
.getResult()
: literal.result();
}
@@ -303,8 +315,7 @@ Value MaterializeScalarRankSpecializationCase(
for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
bvm.map(std::get<0>(it), std::get<1>(it));
Value unshaped_result =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
.front();
MaterializeRankedOperations(b, loc, bvm, op).front();
// Return as unranked tensor for compatibility with the other cases.
b.create<scf::YieldOp>(
@@ -322,26 +333,29 @@ Value MaterializeEqualShapesRankSpecializationCase(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes,
function_ref<void(OpBuilder &, Location)> else_builder_fn) {
assert(shapes.size() >= 2 &&
"This strategy should only be materialized if there are at least two "
"shapes involved.");
// Materialize all shapes equal predicate.
Value all_shapes_eq;
for (Value s : llvm::drop_begin(shapes)) {
auto literal = b.create<shape::ShapeEqOp>(loc, shapes.front(), s);
all_shapes_eq =
all_shapes_eq
? b.create<mlir::AndOp>(loc, all_shapes_eq, literal).result()
Value all_shapes_eq_or_scalar;
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
assert(
non_scalar_shapes.size() >= 2 &&
"Equal shapes strategy requires at least two non-scalar operand shapes.");
for (Value s : llvm::drop_begin(non_scalar_shapes)) {
auto literal =
b.create<shape::ShapeEqOp>(loc, non_scalar_shapes.front(), s);
all_shapes_eq_or_scalar =
all_shapes_eq_or_scalar
? b.create<mlir::AndOp>(loc, all_shapes_eq_or_scalar, literal)
.result()
: literal;
}
auto if_op = b.create<scf::IfOp>(
loc, op->getResultTypes(), all_shapes_eq,
loc, op->getResultTypes(), all_shapes_eq_or_scalar,
[&](OpBuilder &b, Location loc) {
// Flatten operands.
Value shape = shapes.front();
for (Value s : llvm::drop_begin(shapes)) {
// Flatten non-scalar operands.
Value shape = non_scalar_shapes.front();
for (Value s : llvm::drop_begin(non_scalar_shapes)) {
shape = b.create<shape::AnyOp>(loc, shape.getType(),
ValueRange{shape, s});
}
@@ -350,6 +364,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
.result());
auto flat_operands =
llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
if (IsScalarTensorType(v.getType())) return v;
return b
.create<mhlo::DynamicReshapeOp>(
loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
@@ -358,13 +373,11 @@ Value MaterializeEqualShapesRankSpecializationCase(
}));
// Materialize ranked variants for the element-wise operations.
// TODO(frgossen): Materialize non-broadcasting equivalents instead.
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
bvm.map(std::get<0>(it), std::get<1>(it));
Value unshaped_result =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1)
.front();
MaterializeRankedOperations(b, loc, bvm, op).front();
// Return as unranked tensor for compatibility with the other cases.
b.create<scf::YieldOp>(
@@ -382,8 +395,6 @@ Value MaterializeTargetRankSpecializationCase(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes, int64_t target_rank) {
// Reshape operands to match the target rank.
llvm::SmallVector<int64_t, 8> ranked_ty_dynamic_dims(
target_rank, RankedTensorType::kDynamicSize);
RankedTensorType extent_tensor_ty =
shape::getExtentTensorType(b.getContext(), target_rank);
Value all_ones_shape = b.create<shape::ConstShapeOp>(
@@ -394,16 +405,19 @@ Value MaterializeTargetRankSpecializationCase(
for (auto it : llvm::zip(op.operands(), shapes)) {
Value operand, shape;
std::tie(operand, shape) = it;
if (operand.getType().isa<RankedTensorType>()) {
ranked_operands.push_back(operand);
continue;
}
Value ranked_shape = b.create<tensor::CastOp>(
loc, extent_tensor_ty,
b.create<shape::BroadcastOp>(loc,
shape::getExtentTensorType(b.getContext()),
shape, all_ones_shape,
/*error=*/nullptr));
Type element_ty = operand.getType().dyn_cast<TensorType>().getElementType();
auto ranked_ty = RankedTensorType::get(ranked_ty_dynamic_dims, element_ty);
ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
loc, ranked_ty, operand, ranked_shape));
loc, DeriveRankedTensorTypes(operand.getType(), target_rank), operand,
ranked_shape));
}
// Materialize ranked versions of the element-wise operations.
@@ -412,8 +426,7 @@ Value MaterializeTargetRankSpecializationCase(
bvm.map(std::get<0>(it), std::get<1>(it));
// Return as unranked for compatibility with other target ranks.
auto unshaped_result =
MaterializeRankedOperations(b, loc, bvm, op, target_rank).front();
auto unshaped_result = MaterializeRankedOperations(b, loc, bvm, op).front();
return b.create<tensor::CastOp>(
loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
unshaped_result);
@@ -460,25 +473,17 @@ Value MaterializeGenericRankSpecializationCases(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes) {
// Get the minimum broadcast shapes of the operands.
ValueRange reduced_shapes =
b.create<chlo::MinimumBroadcastShapesOp>(
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
auto min_bcast_shapes_op = b.create<chlo::MinimumBroadcastShapesOp>(
loc,
SmallVector<Type, 8>(shapes.size(),
SmallVector<Type, 8>(non_scalar_shapes.size(),
shape::getExtentTensorType(b.getContext())),
shapes)
.results();
// TODO(frgossen): Avoid this reshape if it is redundant in all cases.
SmallVector<Value, 8> reshaped_args;
for (auto it : llvm::zip(op.operands(), reduced_shapes)) {
Value arg = std::get<0>(it);
Value reduced_shape = std::get<1>(it);
reshaped_args.push_back(b.create<mhlo::DynamicReshapeOp>(
loc, arg.getType(), arg, reduced_shape));
}
non_scalar_shapes);
// Find the maximum rank among the reduced operand shapes.
Value max_rank;
for (Value shape : reduced_shapes) {
for (Value shape : min_bcast_shapes_op.results()) {
Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
if (!max_rank) {
max_rank = rank;
@@ -489,6 +494,17 @@ Value MaterializeGenericRankSpecializationCases(
}
}
// Collect reduced shapes.
SmallVector<Value, 8> reduced_shapes;
auto it = min_bcast_shapes_op.result_begin();
for (Value s : shapes) {
if (IsScalarShapeType(s.getType())) {
reduced_shapes.push_back(s);
} else {
reduced_shapes.push_back(*it++);
}
}
// Materialize rank specialization for ranks 1, ..., 8.
// TODO(frgossen): For clusters w/o a select operation, consider only ranks
// 1, ..., 5.
@@ -508,34 +524,6 @@ Value MaterializeDefaultRankSpecializationCases(
});
}
Value MaterializeRankSpecializationForExactlyTwoOperands(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
assert(op->getNumOperands() == 2 && op.getNumResults() == 1 &&
"The rank specialization strategy for clusters with exactly two "
"operands supports only one result.");
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result();
}));
// Materialize all the different cases.
Value unshaped_result = MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, /*non_scalar_idx=*/1,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, /*non_scalar_idx=*/0,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeDefaultRankSpecializationCases(
b, loc, op, shapes));
}));
});
// Materialize final reshape once and for all rank specialization cases.
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
}
SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
Value non_scalar_operand) {
@@ -552,17 +540,48 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
// Materialize ranked variants for the element-wise operations.
BlockAndValueMapping bvm;
for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
Value operand = std::get<1>(it);
bvm.map(std::get<0>(it),
Value operand;
Value bb_arg;
std::tie(bb_arg, operand) = it;
bvm.map(bb_arg,
operand == non_scalar_operand ? flat_non_scalar_operand : operand);
}
SmallVector<Value, 8> unshaped_results =
MaterializeRankedOperations(b, loc, bvm, op, /*target_rank=*/1);
MaterializeRankedOperations(b, loc, bvm, op);
// Restore the results' expected shape.
return MaterializeFinalReshape(b, loc, op, unshaped_results);
}
Value MaterializeRankSpecializationForTwoNonScalarOperands(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
ValueRange non_scalar_operands) {
assert(non_scalar_operands.size() == 2);
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result();
}));
auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
// Materialize all the different cases.
Value unshaped_result = MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, non_scalar_rhs.getIndex(),
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeScalarRankSpecializationCase(
b, loc, op, shapes, non_scalar_lhs.getIndex(),
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, MaterializeDefaultRankSpecializationCases(
b, loc, op, shapes));
}));
});
// Materialize final reshape once and for all rank specialization cases.
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
}
// Materialize the generic rank specialization.
Value MaterializeDefaultRankSpecialization(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
@@ -584,31 +603,36 @@ struct LowerRankSpecializationClusterPattern
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Restoring the result shape currently relies on all operands being used
// for a single result. The result shape is then the broadcasted shape of
// all operands.
if (op.getNumResults() != 1) return failure();
// TODO(frgossen): If there is a single operand, we can flatten it
// completely and apply a non-broadcasting operation.
// If there is only one unranked operand and all others are known scalars,
// we can flatten the operands to rank 1.
if (Optional<Value> non_scalar_operand =
FindUniqueNonScalar(op.operands())) {
// If there is only a single non-scalar operand, we can flatten that operand
// completely.
Location loc = op.getLoc();
auto non_scalar_operands =
llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
return !IsScalarTensorType(v.getType());
}));
if (non_scalar_operands.size() == 1) {
rewriter.replaceOp(op,
MaterializeRankSpecializationForSingleNonScalarOperand(
rewriter, loc, op, *non_scalar_operand));
rewriter, loc, op, non_scalar_operands.front()));
return success();
}
// If there are only two operands, we can consider extra cases in which
// either operand is scalar.
if (op->getNumOperands() == 2) {
rewriter.replaceOp(op, MaterializeRankSpecializationForExactlyTwoOperands(
rewriter, loc, op));
// If there are exactly two unranked operands and all others are known to be
// scalars, we can consider two extra cases: If either of the unranked
// operands turns out to be a scalar at runtime, we can, again, apply the
// trick for a single non-scalar operand.
if (non_scalar_operands.size() == 2 &&
llvm::all_of(non_scalar_operands, [](Value v) {
return v.getType().isa<UnrankedTensorType>();
})) {
rewriter.replaceOp(op,
MaterializeRankSpecializationForTwoNonScalarOperands(
rewriter, loc, op, non_scalar_operands));
return success();
}
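A rough sketch of what the single-non-scalar path now materializes for a cluster like @mul_with_scalar above (SSA names assumed; the authoritative output is checked in the test diff below): only the unranked operand is flattened to rank 1, the scalar is passed through unchanged, and each ranked op derives its result rank as the maximum of its operands' ranks, here max(1, 0) = 1.

// Flatten only the unranked operand (sketch, not taken from the commit).
%shape_arg0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
%n = shape.num_elements %shape_arg0 : tensor<?xindex> -> index
%flat_shape = tensor.from_elements %n : tensor<1xindex>
%flat_arg0 = "mhlo.dynamic_reshape"(%arg0, %flat_shape)
    : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The scalar operand %arg1 : tensor<f32> is used as-is; the ranked multiply
// takes the maximum operand rank, i.e. rank 1.
%ranked_res = chlo.broadcast_multiply %flat_arg0, %arg1
    : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// Cast back to unranked; the final reshape to the broadcast result shape is
// materialized once for all cases, as before.
%unshaped_res = tensor.cast %ranked_res : tensor<?xf32> to tensor<*xf32>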


@@ -267,7 +267,7 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
// -----
// Ternary operation.
// Operation with mixed ranked and unranked operands.
// CHECK-LABEL: @select_mixed
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
@@ -284,6 +284,7 @@ func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
// CHECK-SCF-LABEL: @select_mixed
// CHECK-SCF: chlo.broadcast_select %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF: return
// -----
@@ -417,6 +418,7 @@ func @angle(%arg : tensor<*xcomplex<f32>>) -> tensor<*xf32> {
// -----
// Scalar cluster operand.
// CHECK-LABEL: @xlogy
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
@@ -441,11 +443,87 @@ func @xlogy(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
return %5 : tensor<*xf32>
}
// CHECK-SCF-LABEL: @xlogy
// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %{{.*}}, %{{.*}} {comparison_direction = "EQ"} : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%{{.*}}) : (tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %{{.*}}, %[[TMP0]] : (tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF: chlo.broadcast_select %[[PRED]], %{{.*}}, %[[TMP1]] : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>)
// CHECK-SCF: @xlogy
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
// CHECK-SCF-DAG: %[[ONE_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.00{{.*}}>
// Lhs scalar case:
// CHECK-SCF-DAG: %[[LHS_N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[LHS_SCALAR:.*]] = cmpi eq, %[[LHS_N]], %[[C1]]
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = scf.if %[[LHS_SCALAR]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG0]])
// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>)
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_NON_SCALAR]]) : (tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[SCALAR]], %[[TMP0]] : (tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<i1>, tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Rhs scalar case:
// CHECK-SCF-DAG: %[[RHS_N:.*]] = shape.num_elements %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[RHS_SCALAR:.*]] = cmpi eq, %[[RHS_N]], %[[C1]]
// CHECK-SCF: %{{.*}} = scf.if %[[RHS_SCALAR]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_NON_SCALAR:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[SCALAR:.*]] = "mhlo.reshape"(%[[ARG1]])
// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_NON_SCALAR]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[SCALAR]]) : (tensor<f32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_NON_SCALAR]], %[[TMP0]] : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Equal shapes case:
// CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF: %{{.*}} = scf.if %[[SHAPES_EQ]]
// CHECK-SCF-DAG: %[[SHAPE:.*]] = shape.any %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FLAT_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[FLAT_ARG1]]) : (tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[FLAT_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// Find maximum reduced rank.
// CHECK-SCF-DAG: %[[REDUCED_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_RANK0:.*]] = shape.rank %[[REDUCED_SHAPES]]#0
// CHECK-SCF-DAG: %[[REDUCED_RANK1:.*]] = shape.rank %[[REDUCED_SHAPES]]#1
// CHECK-SCF-DAG: %[[R0_GT_R1:.*]] = cmpi sgt, %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// CHECK-SCF-DAG: %[[MAX_RED_RANK:.*]] = select %[[R0_GT_R1]], %[[REDUCED_RANK0]], %[[REDUCED_RANK1]]
// Generic case 1:
// CHECK-SCF: %[[MAX_RED_RANK_LE_1:.*]] = cmpi ule, %[[MAX_RED_RANK]], %[[C1]]
// CHECK-SCF: %{{.*}} = scf.if %[[MAX_RED_RANK_LE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#0, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1:.*]] = shape.broadcast %[[REDUCED_SHAPES]]#1, %[[ONE_SHAPE_1]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG0_:.*]] = tensor.cast %[[EXT_SHAPE_ARG0]]
// CHECK-SCF-DAG: %[[EXT_SHAPE_ARG1_:.*]] = tensor.cast %[[EXT_SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[REDUCED_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[EXT_SHAPE_ARG0_]])
// CHECK-SCF-DAG: %[[REDUCED_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[EXT_SHAPE_ARG1_]])
// CHECK-SCF-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[REDUCED_ARG0]], %[[ZERO]] {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<f32>)
// CHECK-SCF-DAG: %[[TMP0:.*]] = "mhlo.log"(%[[REDUCED_ARG1]]) : (tensor<?xf32>)
// CHECK-SCF-DAG: %[[TMP1:.*]] = chlo.broadcast_multiply %[[REDUCED_ARG0]], %[[TMP0]] : (tensor<?xf32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES:.*]] = chlo.broadcast_select %[[PRED]], %[[ZERO]], %[[TMP1]] : (tensor<?xi1>, tensor<f32>, tensor<?xf32>)
// CHECK-SCF-DAG: %[[INNER_RES_:.*]] = tensor.cast %[[INNER_RES]]
// CHECK-SCF: scf.yield %[[INNER_RES_]]
// CHECK-SCF: else
// ...
// Reshape the result.
// CHECK-SCF-DAG: %[[SHAPE_ARG0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[SHAPE_ARG1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]]
// -----