[MLIR][HLO] Remove unused pass `TransformUnrankedHloPass`

The pass was replaced by the new generalized rank specialization and the two
passes `mhlo-rank-specialization-cluster` and `mhlo-rank-specialization-to-scf`.

PiperOrigin-RevId: 379935562
This commit is contained in:
A. Unique TensorFlower 2021-06-17 05:19:54 -07:00 committed by TensorFlow MLIR Team
parent 10634ca3a6
commit 470ac45f45
8 changed files with 1 additions and 1068 deletions

24
BUILD
View File

@ -857,29 +857,6 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# Library target for the unranked-HLO transformation pass. It builds
# transform_unranked_hlo.cc and exposes the shared passes.h / rewriters.h
# headers of the mhlo transforms directory.
cc_library(
name = "transform_unranked_hlo",
srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
],
deps = [
":hlo",
":map_chlo_to_hlo_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "broadcast_propagation", name = "broadcast_propagation",
srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"], srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
@ -1379,7 +1356,6 @@ cc_library(
":rank_specialization", ":rank_specialization",
":sink_constants_to_control_flow", ":sink_constants_to_control_flow",
":test_passes", ":test_passes",
":transform_unranked_hlo",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],
) )

View File

@ -112,12 +112,6 @@ def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-m
let constructor = "createTestInferShapedTypeMethodsPass()"; let constructor = "createTestInferShapedTypeMethodsPass()";
} }
// Declares the function pass exposed on the command line as
// `mhlo-transform-unranked-hlo`; instances are built through
// `createTransformUnrankedHloPass()`.
def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
let summary = "Realize element-wise operations on ranked tensors where possible.";
let constructor = "createTransformUnrankedHloPass()";
}
def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> { def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
let summary = "Move dynamic broadcasts up over element-wise operations and " let summary = "Move dynamic broadcasts up over element-wise operations and "
"broadcast the operands rather than the result. This will eventually allow " "broadcast the operands rather than the result. This will eventually allow "

View File

@ -32,9 +32,6 @@ class Pass;
namespace mhlo { namespace mhlo {
/// Creates the function pass that transforms unranked HLO operations to ranked
/// ones where possible (registered as `mhlo-transform-unranked-hlo`).
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
/// Lowers HLO control flow ops to the Standard dialect. /// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

View File

@ -81,10 +81,6 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
// Sets up legality definitions for element-wise operations on ranked tensors.
// NOTE(review): no definition of this function is visible in
// transform_unranked_hlo.cc, where the pass configures its ConversionTarget
// inline - confirm whether this declaration is still backed by a definition.
void SetupTransformUnrankedHloLegality(MLIRContext *context,
ConversionTarget *conversionTarget);
// Populates a collection of rewrite patterns to realize element-wise operations // Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible. // on ranked tensors where possible.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateTransformUnrankedHloPatterns(MLIRContext *context,

View File

@ -62,7 +62,6 @@ add_mlir_library(MhloPasses
rank_specialization.cc rank_specialization.cc
sink_constants_to_control_flow.cc sink_constants_to_control_flow.cc
test_infer_shaped_type_pass.cc test_infer_shaped_type_pass.cc
transform_unranked_hlo.cc
unfuse_batch_norm.cc unfuse_batch_norm.cc
unfuse_batch_norm_pass.cc unfuse_batch_norm_pass.cc

View File

@ -52,7 +52,7 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto result_ty = op.getType().cast<ShapedType>(); auto result_ty = op.getType().cast<ShapedType>();
// Unranked uses are not supported. Consider `mhlo-transform-unranked-hlo`. // Unranked uses are not supported.
if (!result_ty.hasRank()) return failure(); if (!result_ty.hasRank()) return failure();
// Lower to MHLO constant if statically shaped. // Lower to MHLO constant if statically shaped.

View File

@ -1,619 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace {
// X-macro tables of element-wise ops. Each table applies `fn` to every op name
// and joins the expansions with `sep`, so the same list can expand to a
// sequence of statements (`sep` is `;`) or to a template argument list (`sep`
// is a comma supplied through a helper macro).

// Unary element-wise ops; expanded in this file with the `mhlo::` prefix.
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep) \
fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(ConvertOp) sep fn(CosOp) \
sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp) \
sep fn(IsFiniteOp) sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) \
sep fn(NotOp) sep fn(NegOp) sep fn(PopulationCountOp) \
sep fn(RealOp) sep fn(RoundOp) sep fn(RsqrtOp) \
sep fn(SignOp) sep fn(SinOp) sep fn(SqrtOp) \
sep fn(TanhOp)
// Binary element-wise ops; expanded in this file with the `mhlo::` prefix.
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \
sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// Unary element-wise ops; expanded in this file with the `chlo::` prefix.
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp) \
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
sep fn(SinhOp) sep fn(TanOp)
// Binary element-wise ops; expanded in this file with the `chlo::` prefix.
// Only used for pattern registration, not for the per-op legality setup.
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(PolygammaOp) sep fn(ZetaOp)
/// Registers `OpTy` on `target` as dynamically legal: an instance counts as
/// legal exactly when each of its operands has a ranked tensor type.
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
    // Any operand that is not a ranked tensor makes the op illegal.
    for (Type operandTy : op.getOperation()->getOperandTypes()) {
      if (!operandTy.isa<RankedTensorType>()) return false;
    }
    return true;
  });
}
/// Rewrites an element-wise operation that has at least one unranked operand
/// into an equivalent computation on rank-1 tensors. Because the operation is
/// element-wise, flattening does not change the result. The rewrite
///   (i)   reshapes every operand to a 1-D tensor of `num_elements` entries,
///   (ii)  applies the same operation (with its attributes) to those tensors,
///   (iii) reshapes the flat result back to the original shape and type.
template <typename OpTy>
struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
  explicit ElementwiseOpConversion(MLIRContext *context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Operation *rawOp = op.getOperation();

    // Fully ranked operations are left untouched.
    bool anyUnrankedOperand =
        llvm::any_of(rawOp->getOperands(), [](Value candidate) {
          return candidate.getType().isa<UnrankedTensorType>();
        });
    if (!anyUnrankedOperand) return failure();

    Location loc = op.getLoc();

    // Materialize the shape of every operand as an extent tensor.
    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
    SmallVector<Value, 3> shapesOfOperands;
    for (Value candidate : rawOp->getOperands()) {
      Value shapeOfCandidate =
          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, candidate);
      shapesOfOperands.push_back(shapeOfCandidate);
    }

    // With several operands, `shape.any` picks one of the operand shapes; a
    // single operand needs no such op.
    Value commonShape;
    if (shapesOfOperands.size() == 1) {
      commonShape = shapesOfOperands.front();
    } else {
      commonShape =
          rewriter.create<shape::AnyOp>(loc, extentTensorTy, shapesOfOperands);
    }

    // The flat shape is the one-element tensor [num_elements].
    Value elementCount = rewriter.create<shape::NumElementsOp>(
        loc, rewriter.getIndexType(), commonShape);
    Value flatShape =
        rewriter.create<tensor::FromElementsOp>(loc, elementCount);

    // Reshape each operand to tensor<?xElementTy>.
    SmallVector<Value, 3> flattened;
    for (Value candidate : rawOp->getOperands()) {
      Type elementTy =
          candidate.getType().template cast<ShapedType>().getElementType();
      Type rankOneTy =
          RankedTensorType::get({ShapedType::kDynamicSize}, elementTy);
      Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, rankOneTy, candidate, flatShape);
      flattened.push_back(flatOperand);
    }

    // Re-create the operation on the flattened operands.
    Type flatResultTy = RankedTensorType::get(
        {ShapedType::kDynamicSize},
        op.getType().template cast<ShapedType>().getElementType());
    Value flatResult =
        rewriter.create<OpTy>(loc, flatResultTy, flattened, op->getAttrs());

    // Give the result its original shape and type again.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
                                                        flatResult, commonShape);
    return success();
  }
};
// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
//
// `HloOpTy` and `Adaptor` are not used by the body; they are presumably
// required so the pattern matches the template-template parameter expected by
// `chlo::PopulateForBroadcastingBinaryOp` - confirm against that helper.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  // Matches `op` when exactly one operand is a rank-0 ranked tensor and the
  // other one is unranked; fails (without creating any IR) otherwise.
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();

    bool lhs_is_scalar = lhs_ranked_type &&
                         lhs_ranked_type.getShape().empty() &&
                         rhs_unranked_type;
    bool rhs_is_scalar = rhs_ranked_type &&
                         rhs_ranked_type.getShape().empty() &&
                         lhs_unranked_type;

    // Only support the case where exactly one operand is scalar and the other
    // is unranked. Other patterns in chlo-to-hlo legalization will create more
    // efficient lowerings for cases where both ranks are known or will handle
    // the more generic case of both inputs being unranked.
    if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();

    // Bail out before emitting any IR if the result is not a tensor; the
    // element type is read from it below.
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
    if (!result_type) return failure();
    auto result_element_type = result_type.getElementType();

    // The flattened tensor must keep the element type of the operand being
    // reshaped (the unranked one), not that of the scalar operand.
    Value non_scalar = lhs_is_scalar ? rhs : lhs;
    Type non_scalar_element_type = lhs_is_scalar
                                       ? rhs_unranked_type.getElementType()
                                       : lhs_unranked_type.getElementType();

    // Reshape the non-scalar value into a dynamically sized, rank-1 tensor
    Value shape = rewriter.create<shape::ShapeOfOp>(loc, non_scalar);
    Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value size_tensor =
        rewriter.create<tensor::FromElementsOp>(loc, num_elements);
    Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
        loc,
        RankedTensorType::get({RankedTensorType::kDynamicSize},
                              non_scalar_element_type),
        non_scalar, size_tensor);

    // Create a new ranked Chlo op that will be further lowered by other
    // patterns into Mhlo.
    SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                       rhs_is_scalar ? rhs : reshaped};
    Value computed = rewriter.create<ChloOpTy>(
        loc,
        TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize},
                                        result_element_type)},
        new_operands, op->getAttrs());

    // Reshape the result back into an unranked tensor.
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
                                                        computed, shape);
    return success();
  }
};
// Handles lowering of the following pattern to patterns that will be further
// matched by other patterns until they result in LHLO:
//   %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
//
// The sequence of specializations this handles is:
//   - At most one operand has a shape that does not consist of exactly one
//     element.
//   - All operands having equal shapes
//   - The resulting minimized shapes being any of ranks [1,5] (or [1,8] when
//     the op has more than two operands)
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertUnrankedDynamicBroadcastNaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  // Replaces `op` by a nest of `scf.if` ops selecting one of the
  // specializations listed above. Fails, without creating any IR, unless all
  // operands and the result are tensors and at least one operand is unranked.
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    typename ChloOpTy::Adaptor transformed(operands);
    ValueRange transformed_operands = transformed.getOperands();
    const size_t num_operands = transformed_operands.size();

    // The shape-equality specialization below compares `shapes[0]` with
    // `shapes[1]`, so at least two operands are required.
    if (num_operands < 2) return failure();

    llvm::SmallVector<Type, 3> operand_element_types;
    operand_element_types.reserve(num_operands);
    bool has_unranked_tensor_type = false;
    for (Value operand : transformed_operands) {
      auto type = operand.getType().dyn_cast<TensorType>();
      if (!type) return failure();
      if (type.isa<UnrankedTensorType>()) has_unranked_tensor_type = true;
      operand_element_types.push_back(type.getElementType());
    }
    if (!has_unranked_tensor_type) return failure();

    // The element type of the result is needed below; reject non-tensor
    // results before any IR is created.
    auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
    if (!result_type) return failure();

    llvm::SmallVector<Value> shapes;
    shapes.reserve(num_operands);
    for (Value operand : transformed_operands) {
      shapes.push_back(rewriter.create<shape::ShapeOfOp>(loc, operand));
    }

    // If at most one shape does not have exactly one element
    //
    // Count the operands whose shape holds exactly one element.
    Value counter = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    for (Value shape : shapes) {
      Value is_scalar_like = IsSingleElementShape(rewriter, op, shape);
      Value counter_plus_one = rewriter.create<AddIOp>(loc, counter, one);
      counter = rewriter.create<SelectOp>(loc, is_scalar_like, counter_plus_one,
                                          counter);
    }
    Value num_operands_minus_one =
        rewriter.create<ConstantIndexOp>(loc, num_operands - 1);
    Value at_most_one_non_scalar =
        rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::uge,
                                counter, num_operands_minus_one);

    auto if_op = rewriter.create<scf::IfOp>(loc, result_type,
                                            at_most_one_non_scalar, true);
    OpBuilder if_at_most_one_non_scalar_builder =
        if_op.getThenBodyBuilder(rewriter.getListener());

    // Flatten every operand to rank 1 and apply the op on the flat tensors.
    llvm::SmallVector<Value, 3> reshaped_operands;
    reshaped_operands.reserve(num_operands);
    for (size_t i = 0; i < num_operands; ++i) {
      Value num_elements =
          if_at_most_one_non_scalar_builder.create<shape::NumElementsOp>(
              loc, shapes[i]);
      Value size_tensor =
          if_at_most_one_non_scalar_builder.create<tensor::FromElementsOp>(
              loc, num_elements);
      Value reshaped =
          if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
              loc,
              RankedTensorType::get({RankedTensorType::kDynamicSize},
                                    operand_element_types[i]),
              transformed_operands[i], size_tensor);
      reshaped_operands.push_back(reshaped);
    }

    auto rank_one_result_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, result_type.getElementType());
    Value if_at_most_one_non_scalar_result =
        if_at_most_one_non_scalar_builder.create<ChloOpTy>(
            loc, ArrayRef<Type>{rank_one_result_type}, reshaped_operands,
            op->getAttrs());
    Value extended_result = extendToBroadcastShape(
        if_at_most_one_non_scalar_builder, loc, result_type,
        if_at_most_one_non_scalar_result, shapes);
    if_at_most_one_non_scalar_builder.create<scf::YieldOp>(loc,
                                                           extended_result);

    // If there is more than one shape which does not have exactly one element
    //
    // See if all shapes are equal.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
    Value equal_shapes =
        else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[1]);
    for (size_t i = 2; i < num_operands; ++i) {
      Value are_equal =
          else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[i]);
      equal_shapes = else_builder.create<AndOp>(loc, equal_shapes, are_equal);
    }

    auto if_eq_shapes_op =
        else_builder.create<scf::IfOp>(loc, result_type, equal_shapes, true);
    else_builder.create<scf::YieldOp>(loc, if_eq_shapes_op.getResult(0));

    OpBuilder if_eq_shapes_builder =
        if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
    Value non_broadcast_op = Adaptor::CreateOp(
        op, result_type, transformed_operands, if_eq_shapes_builder);
    if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);

    // If shapes do not have exactly one element, nor are equal
    //
    // See if values are of a rank that we support.
    OpBuilder if_neq_shapes_builder =
        if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
    if_neq_shapes_builder.create<scf::YieldOp>(
        loc,
        HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands));

    rewriter.replaceOp(op, {if_op.getResult(0)});
    return success();
  }

 private:
  // Returns the dynamic result of checking the given value is effectively a
  // scalar shape (i.e. the number of elements is 1).
  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
                             Value shape_of_tensor) const {
    auto loc = op.getLoc();
    Value num_elements =
        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
    return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
                                   num_elements,
                                   rewriter.create<ConstantIndexOp>(loc, 1));
  }

  // Reshapes `value` to `result_type`, using the broadcast of all `shapes` as
  // the target shape.
  Value extendToBroadcastShape(OpBuilder &builder, Location loc,
                               Type result_type, Value value,
                               ValueRange shapes) const {
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    Value broadcast_shape = builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, shapes, nullptr);
    return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
                                                  broadcast_shape);
  }

  // Returns the dynamic result of checking that `actual_rank` is equal to
  // `targeted_rank`.
  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
                       int targeted_rank) const {
    return builder.create<CmpIOp>(
        loc, CmpIPredicate::eq, actual_rank,
        builder.create<ConstantIndexOp>(loc, targeted_rank));
  }

  // Creates an `scf.if` (with an else region) whose condition is
  // `actual_rank == targeted_rank` and whose result has the type of `op`'s
  // result.
  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
      OpBuilder &builder, ChloOpTy op, Value actual_rank,
      int targeted_rank) const {
    // Create the if block to place the current specialized logic in.
    Value greater_rank_is_n =
        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
                                     greater_rank_is_n, true);
  }

  // Broadcasts `shape` against a constant shape of `targeted_rank` ones and
  // casts the result to an extent tensor of static size `targeted_rank`.
  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value shape,
                                   int targeted_rank) const {
    auto loc = op.getLoc();
    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
        {RankedTensorType::kDynamicSize}, builder.getIndexType());
    auto known_rank_extent_tensor_type =
        RankedTensorType::get({targeted_rank}, builder.getIndexType());
    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
        loc, known_rank_extent_tensor_type,
        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
                                        ranked_shape));
    Value extended_value = builder.create<shape::BroadcastOp>(
        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
                                          extended_value);
  }

  // Emits, at the insertion point of `if_builder`, the code for a broadcasting
  // op whose operands are all reshaped to `targeted_rank`, and yields its
  // result as an unranked tensor.
  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
                                           ValueRange operands,
                                           ValueRange operand_shapes,
                                           int targeted_rank) const {
    auto loc = op.getLoc();
    SmallVector<Value, 2> reshaped_operands;

    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
        targeted_rank, RankedTensorType::kDynamicSize);

    for (auto it : llvm::zip(operands, operand_shapes)) {
      Value operand, shape;
      std::tie(operand, shape) = it;
      // Handle shape broadcasting and inference.
      Value extended_operand_casted =
          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);

      // 1. Reshape operands to the given rank (with the same number of
      //    elements)
      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
      //    ops can be broadcasted and do the actual broadcasting)
      // 3. Type erase the output back to unranked
      auto reshaped_type = RankedTensorType::get(
          dynamic_dimensions,
          operand.getType().template dyn_cast<TensorType>().getElementType());
      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
          loc, reshaped_type, operand, extended_operand_casted);
      reshaped_operands.push_back(reshaped_operand);
    }

    auto result_element_type = op.getResult()
                                   .getType()
                                   .template dyn_cast<TensorType>()
                                   .getElementType();
    auto result_type =
        RankedTensorType::get(dynamic_dimensions, result_element_type);
    Value result = if_builder.create<ChloOpTy>(
        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
    Value reshaped_result = if_builder.create<tensor::CastOp>(
        loc, UnrankedTensorType::get(result_element_type), result);
    if_builder.create<scf::YieldOp>(loc, reshaped_result);
  }

  // Iterates over the desired ranks to be specialized and generates the code
  // snippet for each case.
  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
                             ValueRange operands) const {
    auto loc = op.getLoc();

    // Get the minimum broadcast shapes of the operands.
    SmallVector<Value> shapes;
    shapes.reserve(operands.size());
    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                    rewriter.getIndexType());
    for (Value operand : operands) {
      Value shape =
          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
      shapes.push_back(shape);
    }
    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
        loc, extent_tensor_type, shapes, nullptr);
    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
    auto reduced_shapes =
        rewriter
            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
            .results();
    SmallVector<Value> reshaped_operands;
    reshaped_operands.reserve(operands.size());
    for (auto it : llvm::zip(operands, reduced_shapes)) {
      Value operand;
      Value reduced_shape;
      std::tie(operand, reduced_shape) = it;
      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, operand.getType(), operand, reduced_shape);
      reshaped_operands.push_back(reshaped_operand);
    }

    // Find the largest rank of the operands.
    Value greater_rank;
    for (Value shape : reduced_shapes) {
      Value rank =
          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
      if (!greater_rank) {
        greater_rank = rank;
      } else {
        Value greater_rank_compare = rewriter.create<CmpIOp>(
            loc, CmpIPredicate::sgt, greater_rank, rank);
        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
                                                 greater_rank, rank);
      }
    }

    // Generate a list of nested if/else statements to handle rank
    // specializations from 1 to `kMaxRankSpecialization`.
    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
        rewriter, op, greater_rank, 1);
    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
                                        reduced_shapes, 1);

    // Put each subsequent rank specialization inside the else statement of the
    // previous one.
    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());

    // Tensorflow supports up to rank 8 for SelectOp (currently the only op with
    // arity > 2 that we support), but only up to rank 5 for binary ops. We want
    // to preserve this behavior.
    const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
    for (int i = 2; i < kMaxRankSpecialization; i++) {
      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
          else_builder, op, greater_rank, i);
      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
                                          reduced_shapes, i);
      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
    }

    // Fire an assertion if none of the rank specializations applied (one of
    // the ranks was greater than `kMaxRankSpecialization`).
    else_builder.create<AssertOp>(
        loc,
        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
                       kMaxRankSpecialization),
        "Input for dynamic binary op lowering was of a rank greater than " +
            std::to_string(kMaxRankSpecialization));

    // Add the `kMaxRankSpecialization` specialization to the innermost else
    // block.
    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
                                        reduced_shapes, kMaxRankSpecialization);

    // Return the reshaped result of the outermost if statement.
    auto result = if_op.getResult(0);
    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
        loc, result.getType(), result, broadcast_shape);
    return reshaped_result;
  }
};
// Function pass that applies the patterns from
// `PopulateTransformUnrankedHloPatterns` as a partial conversion. The listed
// element-wise MHLO/CHLO ops are legal only when all their operands are ranked
// tensors; any other CHLO op is legal only when none of its operands is
// unranked.
struct TransformUnrankedHloPass
    : public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
  // Declares every dialect whose operations the rewrite patterns may create.
  // The standard and tensor dialects are listed too, because the patterns
  // build ops such as ConstantIndexOp/CmpIOp/SelectOp/AssertOp and
  // tensor::FromElementsOp/tensor::CastOp; a pass must declare the dialects it
  // creates entities for so they are loaded before it runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
                    shape::ShapeDialect, StandardOpsDialect,
                    tensor::TensorDialect>();
  }

  void runOnFunction() override {
    // Setup conversion target.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
                           StandardOpsDialect, shape::ShapeDialect,
                           scf::SCFDialect, tensor::TensorDialect>();
    target.addLegalOp<FuncOp>();
    // Element-wise ops are only legal on ranked operands.
#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
#undef ADD_LEGAL_MHLO
#undef ADD_LEGAL_CHLO
    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
    // Dialect-wide rule for CHLO: an op is legal when it has no unranked
    // operand.
    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
        [](Operation *op) {
          return llvm::none_of(op->getOperandTypes(), [](Type type) {
            return type.isa<UnrankedTensorType>();
          });
        });

    // Populate rewrite patterns.
    OwningRewritePatternList patterns(&ctx);
    mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace
namespace mhlo {
// Adds to `patterns` every rewrite used to realize element-wise operations on
// ranked tensors: the flattening pattern for each listed MHLO/CHLO op, followed
// by the unranked broadcasting patterns for CHLO n-ary and scalar/unranked
// binary ops.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  // Helpers that turn an op name from the X-macro tables into the matching
  // flattening pattern type; LIST_SEP supplies the comma between entries.
#define FLATTEN_MHLO(op) ElementwiseOpConversion<mhlo::op>
#define FLATTEN_CHLO(op) ElementwiseOpConversion<chlo::op>
#define LIST_SEP ,
  // clang-format off
  patterns->insert<
      MAP_XLA_OPERATION_CWISE_UNARY(FLATTEN_MHLO, LIST_SEP),
      MAP_XLA_OPERATION_CWISE_BINARY(FLATTEN_MHLO, LIST_SEP),
      MAP_CHLO_OPERATION_CWISE_UNARY(FLATTEN_CHLO, LIST_SEP),
      MAP_CHLO_OPERATION_CWISE_BINARY(FLATTEN_CHLO, LIST_SEP),
      ElementwiseOpConversion<mhlo::CompareOp>,
      ElementwiseOpConversion<mhlo::SelectOp>>(context);
  // clang-format on
#undef FLATTEN_MHLO
#undef FLATTEN_CHLO
#undef LIST_SEP

  // Rank-specializing lowering for broadcasting binary ops.
  chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
      context, patterns);

  // Same lowering for the ternary broadcasting select.
  using BroadcastSelectAdaptor =
      chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp, mhlo::SelectOp>;
  patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
      chlo::BroadcastSelectOp, mhlo::SelectOp, BroadcastSelectAdaptor>>(
      context);

  // Cheaper lowering for the scalar-with-unranked binary case.
  chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
}
// Factory for the unranked-HLO transformation pass: returns a newly allocated
// TransformUnrankedHloPass owned by the caller.
std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
return std::make_unique<TransformUnrankedHloPass>();
}
} // namespace mhlo
} // namespace mlir

View File

@ -1,410 +0,0 @@
// RUN: mlir-hlo-opt --mhlo-transform-unranked-hlo --cse --split-input-file %s | FileCheck %s
// Check the validity of expected IR.
// The input below is written directly in the flattened form (flatten, apply
// sqrt on the rank-1 tensor, reshape back); only the function label is matched
// in the output.
// CHECK-LABEL: @sqr_transform_result
func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
// Flatten operand shape.
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
%flat_shape = tensor.from_elements %num_elements : tensor<1xindex>
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Apply operation.
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%b = "mhlo.dynamic_reshape"(%flat_b, %shape)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
// -----
// Check transformation of unranked code.
// An unranked unary op is expected to become: shape_of, num_elements,
// from_elements, reshape to rank 1, the op itself, and a reshape back to the
// original shape, in exactly that order.
// CHECK-LABEL: @sqrt
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
// -----
// Not transformed when ranked.
// A ranked operand with a dynamic dimension must pass through unchanged.
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}
// -----
// Not transformed when statically shaped.
// A fully static operand type must pass through unchanged.
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}
// -----
// Transformed if there is a mix of unranked/static shapes.
// All three operands, including the statically shaped one, are expected to be
// reshaped to rank 1 using a shape chosen by `shape.any`.
// CHECK-LABEL: @select_mixed
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
// CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
// CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
// -----
// Input: non-broadcasting mhlo.add on two unranked operands.
// Expected output: both operand shapes are taken and combined with shape.any,
// the element count of that shape becomes a one-element index tensor, both
// operands are reshaped to 1-D with it, the add runs on the flat tensors, and
// the flat result is reshaped back to the combined shape.
// CHECK-LABEL: @add_unranked
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}
// -----
// Input: unary chlo.tan on an unranked operand.
// Expected output: the operand shape is taken, its element count is wrapped
// into a one-element index tensor, the operand is reshaped to 1-D with it,
// chlo.tan runs on the flat tensor, and the flat result is reshaped back to
// the original operand shape.
// CHECK-LABEL: @tan
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> -> tensor<?xf32>
// CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[B]] : tensor<*xf32>
%result = chlo.tan %a : tensor<*xf32> -> tensor<*xf32>
return %result : tensor<*xf32>
}
// -----
// Test case: broadcasting add of a rank-0 lhs and an unranked rhs.
// Expected output: only the unranked operand is flattened to 1-D, the
// broadcasting add runs on the (scalar, flat) operands, and the flat result is
// reshaped back using the original shape of the unranked operand.
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor. The reshape must consume the add result captured
// above, so the capture is used here instead of being bound again.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// -----
// Test case: broadcasting add of an unranked lhs and a rank-0 rhs.
// Expected output: only the unranked operand is flattened to 1-D, the
// broadcasting add runs on the (flat, scalar) operands, and the flat result is
// reshaped back using the original shape of the unranked operand.
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The broadcasting add is applied directly to the flattened operand and the
// scalar operand; no additional region is expected around it.
// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor. The reshape must consume the add result captured
// above, so the capture is used here instead of being bound again.
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// -----
// Test case: broadcasting add of two unranked operands.
// Expected output is a chain of nested scf.if ops:
//   * scalar case: at least one operand has exactly one element;
//   * equal-shapes case: both operands are flattened and added directly;
//   * otherwise: the operands are reshaped to the minimum broadcast shapes and
//     the add is specialized on the greatest operand rank, from 1 up to 5
//     (rank 5 is asserted rather than tested with another scf.if).
// Every else-branch must yield the result of the scf.if nested inside it, so
// the closing yields reference those results explicitly.
func @addUnrankedUnranked(
%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE2:.*]] = addi %[[COUNTER]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE2]], %[[COUNTER]] : index
// Handle scalar case
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER2]], %[[C1]] : index
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
// Handle rank 2 specialization
// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
// Handle rank 3 specialization
// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
// Handle rank 4 specialization
// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
// Handle rank 5 specialization
// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
// CHECK-NEXT: }
// Each enclosing else-branch forwards the result of its nested scf.if.
// CHECK-NEXT: scf.yield %[[VAL_42]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_34]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_26]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[VAL_18]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[VAL_8]] : tensor<*xf32>
// CHECK-NEXT: }
// -----
// Test case: broadcasting select with three unranked operands.
// Expected output mirrors the two-operand broadcasting add case:
//   * scalar case: at least two of the three operands have exactly one element;
//   * equal-shapes case: all operands are flattened and mhlo.select is applied;
//   * otherwise: the operands are reshaped to the minimum broadcast shapes and
//     the select is specialized on the greatest operand rank.
// Only the rank 1 specialization is matched line by line; the higher ranks
// visible here are matched by operand types alone.
// Index constants are bound as captures so the expectations do not depend on
// the names the printer assigns to them.
func @selectUnrankedUnrankedUnranked(
%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>)
-> tensor<*xf32> {
%0 = chlo.broadcast_select %arg0, %arg1, %arg2
: (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @selectUnrankedUnrankedUnranked(
// CHECK-SAME: %[[PRED:.*]]: tensor<*xi1>,
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index
// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index
// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %[[C1]] : index
// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index
// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %[[C2]] : index
// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1
// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// Handle rank 1 specialization
// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// The cast must consume the select result captured on the previous line.
// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
// CHECK-NEXT: }
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>