From 470ac45f4508f603b0cd552243ee74cbb2083976 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 17 Jun 2021 05:19:54 -0700
Subject: [PATCH] [MLIR][HLO] Remove unused pass `TransformUnrankedHloPass`

The pass has been superseded by the new, generalized rank specialization,
implemented by the two passes `mhlo-rank-specialization-cluster` and
`mhlo-rank-specialization-to-scf`.

PiperOrigin-RevId: 379935562
---
 BUILD                                         |  24 -
 .../Dialect/mhlo/transforms/mhlo_passes.td    |   6 -
 .../mlir-hlo/Dialect/mhlo/transforms/passes.h |   3 -
 .../Dialect/mhlo/transforms/rewriters.h       |   4 -
 lib/Dialect/mhlo/transforms/CMakeLists.txt    |   1 -
 .../mhlo/transforms/chlo_legalize_to_hlo.cc   |   2 +-
 .../mhlo/transforms/transform_unranked_hlo.cc | 619 ------------------
 tests/hlo-transform-unranked.mlir             | 410 ------------
 8 files changed, 1 insertion(+), 1068 deletions(-)
 delete mode 100644 lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
 delete mode 100644 tests/hlo-transform-unranked.mlir

diff --git a/BUILD b/BUILD
index 6c3dc82..8d84c1d 100644
--- a/BUILD
+++ b/BUILD
@@ -857,29 +857,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "transform_unranked_hlo",
-    srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"],
-    hdrs = [
-        "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
-        "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
-    ],
-    deps = [
-        ":hlo",
-        ":map_chlo_to_hlo_op",
-        ":pass_details",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:SCFDialect",
-        "@llvm-project//mlir:Shape",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:TensorDialect",
-        "@llvm-project//mlir:Transforms",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "broadcast_propagation",
     srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
@@ -1379,7 +1356,6 @@ cc_library(
         ":rank_specialization",
        ":sink_constants_to_control_flow",
        ":test_passes",
-       ":transform_unranked_hlo",
        "@llvm-project//mlir:Pass",
     ],
 )
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index f7ee2d8..976940e 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -112,12 +112,6 @@ def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-m
   let constructor = "createTestInferShapedTypeMethodsPass()";
 }
 
-
-def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
-  let summary = "Realize element-wise operations on ranked tensors where possible.";
-  let constructor = "createTransformUnrankedHloPass()";
-}
-
 def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
   let summary = "Move dynamic broadcasts up over element-wise operations and "
       "broadcast the operands rather than the result. This will eventually allow "
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 5eb953f..44aff5c 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -32,9 +32,6 @@ class Pass;
 
 namespace mhlo {
 
-// Transforms unranked HLO operations to ranked ones where possible.
-std::unique_ptr<FunctionPass> createTransformUnrankedHloPass();
-
 /// Lowers HLO control flow ops to the Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index 0d3d641..a0f50ba 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -81,10 +81,6 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
 void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                            OwningRewritePatternList *patterns);
 
-// Sets up legality definitions for element-wise operations on ranked tensors.
-void SetupTransformUnrankedHloLegality(MLIRContext *context,
-                                       ConversionTarget *conversionTarget);
-
 // Populates a collection of rewrite patterns to realize element-wise operations
 // on ranked tensors where possible.
 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt
index 8caba8b..e9214fb 100644
--- a/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -62,7 +62,6 @@ add_mlir_library(MhloPasses
   rank_specialization.cc
   sink_constants_to_control_flow.cc
   test_infer_shaped_type_pass.cc
-  transform_unranked_hlo.cc
   unfuse_batch_norm.cc
   unfuse_batch_norm_pass.cc
diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index 9d36a0f..d8c8683 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -52,7 +52,7 @@ struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto result_ty = op.getType().cast<ShapedType>();
 
-    // Unranked uses are not supported. Consider `mhlo-transform-unranked-hlo`.
+    // Unranked uses are not supported.
     if (!result_ty.hasRank()) return failure();
 
     // Lower to MHLO constant if statically shaped.
diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
deleted file mode 100644
index 00cba94..0000000
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ /dev/null
@@ -1,619 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
-==============================================================================*/
-
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace {
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep)                                \
-  fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(ConvertOp) sep fn(CosOp)      \
-      sep fn(ExpOp) sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp)            \
-          sep fn(IsFiniteOp) sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) \
-              sep fn(NotOp) sep fn(NegOp) sep fn(PopulationCountOp)           \
-                  sep fn(RealOp) sep fn(RoundOp) sep fn(RsqrtOp)              \
-                      sep fn(SignOp) sep fn(SinOp) sep fn(SqrtOp)             \
-                          sep fn(TanhOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep)                               \
-  fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp)     \
-      sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp)    \
-          sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp)    \
-              sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                               \
-  fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp)    \
-      sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(DigammaOp)         \
-          sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp)       \
-              sep fn(SinhOp) sep fn(TanOp)
-
-// TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(PolygammaOp) sep fn(ZetaOp)
-
-template <typename OpTy>
-inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
-  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
-    return llvm::all_of(op.getOperation()->getOperandTypes(),
-                        [&](Type t) { return t.isa<RankedTensorType>(); });
-  });
-}
-
-/// Element-wise operations on unranked tensors can be applied to the flattened
-/// tensor operands with the same effect. This pattern rewrites every such
-/// operation to
-///   (i)   flatten the input tensor,
-///   (ii)  apply the operation, and
-///   (iii) restore the original shape.
-template <typename OpTy>
-struct ElementwiseOpConversion : public OpRewritePattern<OpTy> {
-  explicit ElementwiseOpConversion(MLIRContext *context)
-      : OpRewritePattern<OpTy>(context) {}
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    // Only apply conversion if at least one operand is unranked.
-    if (llvm::none_of(op.getOperation()->getOperands(), [&](Value operand) {
-          return operand.getType().isa<UnrankedTensorType>();
-        })) {
-      return failure();
-    }
-
-    // Get operands' shape.
-    auto loc = op.getLoc();
-    Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext());
-    SmallVector<Value> operandShapes;
-    for (Value operand : op.getOperation()->getOperands()) {
-      Value shape =
-          rewriter.create<shape::ShapeOfOp>(loc, extentTensorTy, operand);
-      operandShapes.push_back(shape);
-    }
-    Value shape =
-        operandShapes.size() == 1
-            ?
operandShapes.front() - : rewriter.create(loc, extentTensorTy, operandShapes); - - // Derive flat shape. - Type indexTy = rewriter.getIndexType(); - Value numElements = - rewriter.create(loc, indexTy, shape); - Value flatShape = rewriter.create(loc, numElements); - - // Flatten operands. - SmallVector flatOperands; - for (Value operand : op.getOperation()->getOperands()) { - Type operandElementTy = - operand.getType().template cast().getElementType(); - Type flatTy = - RankedTensorType::get({ShapedType::kDynamicSize}, operandElementTy); - Value flat = rewriter.create(loc, flatTy, operand, - flatShape); - flatOperands.push_back(flat); - } - - // Apply operation to flattened operands. - Type resultElementTy = - op.getType().template cast().getElementType(); - Type flatResultTy = - RankedTensorType::get({ShapedType::kDynamicSize}, resultElementTy); - Value flatResult = - rewriter.create(loc, flatResultTy, flatOperands, op->getAttrs()); - - // Restore original shape. - rewriter.replaceOpWithNewOp(op, op.getType(), - flatResult, shape); - - return success(); - } -}; - -// Converts a broadcasting binary operation with a scalar operand and an -// unranked operand to a ranked broadcasting operation by dynamically reshaping -// the unranked operand to a 1D tensor. This will always be safe because -// broadcasting from a scalar to another shape always works. -template -struct ConvertUnrankedScalarDynamicBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - typename ChloOpTy::Adaptor transformed(operands); - Value lhs = transformed.lhs(); - Value rhs = transformed.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast(); - auto lhs_unranked_type = lhs.getType().dyn_cast(); - - auto rhs_ranked_type = rhs.getType().dyn_cast(); - auto rhs_unranked_type = rhs.getType().dyn_cast(); - - bool lhs_is_scalar = lhs_ranked_type && - lhs_ranked_type.getShape().empty() && - rhs_unranked_type; - bool rhs_is_scalar = rhs_ranked_type && - rhs_ranked_type.getShape().empty() && - lhs_unranked_type; - - // Only support the case where exactly one operand is scalar and the other - // is unranked. Other patterns in chlo-to-hlo legalization will create more - // efficient lowerings for cases where both ranks are known or will handle - // the more generic case of both inputs being unranked. - if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure(); - - auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType() - : rhs_ranked_type.getElementType(); - auto result_type = op.getResult().getType().template dyn_cast(); - auto result_element_type = result_type.getElementType(); - - // Reshape the non-scalar value into a dynamically sized, rank-1 tensor - Value shape = - rewriter.create(loc, lhs_is_scalar ? rhs : lhs); - Value num_elements = rewriter.create(loc, shape); - Value size_tensor = - rewriter.create(loc, num_elements); - Value reshaped = rewriter.create( - loc, - RankedTensorType::get({RankedTensorType::kDynamicSize}, - scalar_element_type), - lhs_is_scalar ? rhs : lhs, size_tensor); - - // Create a new ranked Chlo op that will be further lowered by other - // patterns into Mhlo. - SmallVector new_operands{lhs_is_scalar ? lhs : reshaped, - rhs_is_scalar ? 
rhs : reshaped}; - Value computed = rewriter.create( - loc, - TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize}, - result_element_type)}, - new_operands, op->getAttrs()); - - // Reshape the result back into an unranked tensor. - rewriter.replaceOpWithNewOp(op, result_type, - computed, shape); - - return success(); - } -}; - -// Handles lowering of the following pattern to patterns that will be further -// matched by other patterns until they result in LHLO: -// %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy> -// -// The sequence of specializations this handles is: -// - At most one operand has a shape that does not consist of exactly one -// element. -// - All operands having equal shapes -// - The resulting minimized shapes being any of ranks [1,5] -template -struct ConvertUnrankedDynamicBroadcastNaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - typename ChloOpTy::Adaptor transformed(operands); - ValueRange transformed_operands = transformed.getOperands(); - auto num_operands = transformed_operands.size(); - llvm::SmallVector operand_element_types; - operand_element_types.reserve(num_operands); - bool has_unranked_tensor_type = false; - for (int i = 0; i < num_operands; ++i) { - if (auto type = - transformed_operands[i].getType().dyn_cast()) { - if (type.isa()) { - has_unranked_tensor_type = true; - } - operand_element_types.push_back(type.getElementType()); - } else { - return failure(); - } - } - if (!has_unranked_tensor_type) return failure(); - auto result_type = op.getResult().getType().template dyn_cast(); - - llvm::SmallVector shapes; - shapes.reserve(num_operands); - for (int i = 0; i < num_operands; ++i) { - shapes.push_back( - rewriter.create(loc, transformed_operands[i])); - } - - // If at most one shape does not have exactly one element - Value counter = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - for (int i = 0; i < num_operands; ++i) { - Value is_scalar_like = IsSingleElementShape(rewriter, op, shapes[i]); - Value counter_plus_one = rewriter.create(loc, counter, one); - counter = rewriter.create(loc, is_scalar_like, counter_plus_one, - counter); - } - Value num_operands_minus_one = - rewriter.create(loc, num_operands - 1); - Value at_most_one_non_scalar = - rewriter.create(loc, rewriter.getI1Type(), CmpIPredicate::uge, - counter, num_operands_minus_one); - - auto if_op = rewriter.create(loc, result_type, - at_most_one_non_scalar, true); - OpBuilder if_at_most_one_non_scalar_builder = - if_op.getThenBodyBuilder(rewriter.getListener()); - - llvm::SmallVector reshaped_operands; - reshaped_operands.reserve(num_operands); - for (int i = 0; i < num_operands; ++i) { - Value num_elements = - if_at_most_one_non_scalar_builder.create( - loc, shapes[i]); - Value size_tensor = - if_at_most_one_non_scalar_builder.create( - loc, num_elements); - Value reshaped = - if_at_most_one_non_scalar_builder.create( - loc, - RankedTensorType::get({RankedTensorType::kDynamicSize}, - operand_element_types[i]), - transformed_operands[i], size_tensor); - reshaped_operands.push_back(reshaped); - } - - auto rank_one_result_type = RankedTensorType::get( - {RankedTensorType::kDynamicSize}, result_type.getElementType()); - Value if_at_most_one_non_scalar_result = - if_at_most_one_non_scalar_builder.create( - loc, ArrayRef{rank_one_result_type}, 
reshaped_operands, - op->getAttrs()); - Value extended_result = extendToBroadcastShape( - if_at_most_one_non_scalar_builder, loc, result_type, - if_at_most_one_non_scalar_result, shapes); - if_at_most_one_non_scalar_builder.create(loc, - extended_result); - - // If there is more than one shape which does not have exactly one element - // - // See if all shapes are equal. - OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - Value equal_shapes = - else_builder.create(loc, shapes[0], shapes[1]); - for (int i = 2; i < num_operands; ++i) { - Value are_equal = - else_builder.create(loc, shapes[0], shapes[i]); - equal_shapes = else_builder.create(loc, equal_shapes, are_equal); - } - - auto if_eq_shapes_op = - else_builder.create(loc, result_type, equal_shapes, true); - else_builder.create(loc, if_eq_shapes_op.getResult(0)); - - OpBuilder if_eq_shapes_builder = - if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener()); - Value non_broadcast_op = Adaptor::CreateOp( - op, result_type, transformed_operands, if_eq_shapes_builder); - if_eq_shapes_builder.create(loc, non_broadcast_op); - - // If shapes do not have exactly one element, nor are equal - // - // See if values are of a rank that we support. - OpBuilder if_neq_shapes_builder = - if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener()); - if_neq_shapes_builder.create( - loc, - HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands)); - - rewriter.replaceOp(op, {if_op.getResult(0)}); - return success(); - } - - private: - // Returns the dynamic result of checking the given value is effectively a - // scalar shape (i.e. the number of elements is 1). - Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op, - Value shape_of_tensor) const { - auto loc = op.getLoc(); - - Value num_elements = - rewriter.create(loc, shape_of_tensor); - return rewriter.create(loc, rewriter.getI1Type(), CmpIPredicate::eq, - num_elements, - rewriter.create(loc, 1)); - } - - Value extendToBroadcastShape(OpBuilder &builder, Location loc, - Type result_type, Value value, - ValueRange shapes) const { - auto unknown_rank_extent_tensor_type = RankedTensorType::get( - {RankedTensorType::kDynamicSize}, builder.getIndexType()); - Value broadcast_shape = builder.create( - loc, unknown_rank_extent_tensor_type, shapes, nullptr); - return builder.create(loc, result_type, value, - broadcast_shape); - } - - // Returns the dynamic result of checking the given value is effectively a - // scalar shape (i.e. the number of elements is 1). - Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank, - int targeted_rank) const { - return builder.create( - loc, CmpIPredicate::eq, actual_rank, - builder.create(loc, targeted_rank)); - } - - scf::IfOp createIfOpForRankSpecializedBroadcastAndOp( - OpBuilder &builder, ChloOpTy op, Value actual_rank, - int targeted_rank) const { - // Create the if block to place the current specialized logic in. 
- Value greater_rank_is_n = - GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank); - return builder.create(op.getLoc(), op.getResult().getType(), - greater_rank_is_n, true); - } - - Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value shape, - int targeted_rank) const { - auto loc = op.getLoc(); - SmallVector ranked_shape(targeted_rank, 1); - auto unknown_rank_extent_tensor_type = RankedTensorType::get( - {RankedTensorType::kDynamicSize}, builder.getIndexType()); - auto known_rank_extent_tensor_type = - RankedTensorType::get({targeted_rank}, builder.getIndexType()); - Value ranked_shape_val = builder.create( - loc, known_rank_extent_tensor_type, - mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, - ranked_shape)); - Value extended_value = builder.create( - loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr); - return builder.create(loc, known_rank_extent_tensor_type, - extended_value); - } - - // Create the if statement and code for a broadcasting op with a result of a - // given rank. - void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op, - ValueRange operands, - ValueRange operand_shapes, - int targeted_rank) const { - auto loc = op.getLoc(); - SmallVector reshaped_operands; - - auto dynamic_dimensions = llvm::SmallVector( - targeted_rank, RankedTensorType::kDynamicSize); - - for (auto it : llvm::zip(operands, operand_shapes)) { - Value operand, shape; - std::tie(operand, shape) = it; - // Handle shape broadcasting and inference. - Value extended_operand_casted = - createBroadcastToKnownRank(if_builder, op, shape, targeted_rank); - - // 1. Reshape operands to the given rank (with the same number of - // elements) - // 2. Compute the ranked-broadcasted ChloOp (which will assert that the - // ops - // can be broadcasted and do the actual broadcasting) - // 3. Type erase the output back to unranked - auto reshaped_type = RankedTensorType::get( - dynamic_dimensions, - operand.getType().template dyn_cast().getElementType()); - Value reshaped_operand = if_builder.create( - loc, reshaped_type, operand, extended_operand_casted); - reshaped_operands.push_back(reshaped_operand); - } - auto result_element_type = op.getResult() - .getType() - .template dyn_cast() - .getElementType(); - auto result_type = - RankedTensorType::get(dynamic_dimensions, result_element_type); - Value result = if_builder.create( - loc, ArrayRef{result_type}, reshaped_operands, op->getAttrs()); - Value reshaped_result = if_builder.create( - loc, UnrankedTensorType::get(result_element_type), result); - if_builder.create(loc, reshaped_result); - } - - // Iterates over the desired ranks to be specialized and generates the code - // snippet for each case. - Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, - ValueRange operands) const { - auto loc = op.getLoc(); - - // Get the minimum broadcast shapes of the operands. 
- SmallVector shapes; - shapes.reserve(operands.size()); - auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - for (Value operand : operands) { - Value shape = - rewriter.create(loc, extent_tensor_type, operand); - shapes.push_back(shape); - } - auto broadcast_shape = rewriter.create( - loc, extent_tensor_type, shapes, nullptr); - SmallVector result_types(shapes.size(), extent_tensor_type); - auto reduced_shapes = - rewriter - .create(loc, result_types, shapes) - .results(); - SmallVector reshaped_operands; - reshaped_operands.reserve(operands.size()); - for (auto it : llvm::zip(operands, reduced_shapes)) { - Value operand; - Value reduced_shape; - std::tie(operand, reduced_shape) = it; - auto reshaped_operand = rewriter.create( - loc, operand.getType(), operand, reduced_shape); - reshaped_operands.push_back(reshaped_operand); - } - - // Find the largest rank of the operands. - Value greater_rank; - for (Value shape : reduced_shapes) { - Value rank = - rewriter.create(loc, rewriter.getIndexType(), shape); - if (!greater_rank) { - greater_rank = rank; - } else { - Value greater_rank_compare = rewriter.create( - loc, CmpIPredicate::sgt, greater_rank, rank); - greater_rank = rewriter.create(loc, greater_rank_compare, - greater_rank, rank); - } - } - - // Generate a list of nested if/else statements to handle rank - // specializations from 1 to `kMaxRankSpecialization`. - scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( - rewriter, op, greater_rank, 1); - OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); - createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, - reduced_shapes, 1); - - // Put each subsequent rank specialization inside the else statement of the - // previous one. - OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - - // Tensorflow supports up to rank 8 for SelectOp (currently the only op with - // arity > 2 that we support), but only up to rank 5 for binary ops. We want - // to preserve this behavior. - const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5; - for (int i = 2; i < kMaxRankSpecialization; i++) { - auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( - else_builder, op, greater_rank, i); - if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); - createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, - reduced_shapes, i); - else_builder.create(loc, inner_if.getResult(0)); - else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); - } - // Fire an assertion if none of the rank specializations applied (one of - // the ranks was greater than `kMaxRankSpecialization`). - else_builder.create( - loc, - GreaterRankIsN(else_builder, op.getLoc(), greater_rank, - kMaxRankSpecialization), - "Input for dynamic binary op lowering was of a rank greater than " + - std::to_string(kMaxRankSpecialization)); - // Add the rank 5 specialization to the innermost else block. - createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands, - reduced_shapes, kMaxRankSpecialization); - - // Return the reshaped result of the outermost if statement. 
-    auto result = if_op.getResult(0);
-    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, result.getType(), result, broadcast_shape);
-    return reshaped_result;
-  }
-};
-
-struct TransformUnrankedHloPass
-    : public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
-  }
-
-  void runOnFunction() override {
-    // Setup conversion target.
-    MLIRContext &ctx = getContext();
-    ConversionTarget target(ctx);
-    target.addLegalDialect();
-    target.addLegalOp();
-#define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
-#define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
-    MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL_MHLO, ;);
-    MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL_MHLO, ;);
-    MAP_CHLO_OPERATION_CWISE_UNARY(ADD_LEGAL_CHLO, ;);
-#undef ADD_LEGAL_MHLO
-#undef ADD_LEGAL_CHLO
-    AddLegalOpOnRankedTensor<mhlo::CompareOp>(&target);
-    AddLegalOpOnRankedTensor<mhlo::SelectOp>(&target);
-    target.addDynamicallyLegalDialect<chlo::HloClientDialect>(
-        [](Operation *op) {
-          return llvm::none_of(op->getOperandTypes(), [](Type type) {
-            return type.isa<UnrankedTensorType>();
-          });
-        });
-
-    // Populate rewrite patterns.
-    OwningRewritePatternList patterns(&ctx);
-    mhlo::PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
-
-    // Apply transformation.
-    if (failed(
-            applyPartialConversion(getFunction(), target, std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-
-}  // namespace
-
-namespace mhlo {
-
-void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
-                                          OwningRewritePatternList *patterns) {
-#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
-#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
-#define COMMA ,
-  // clang-format off
-  patterns->insert<
-      MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
-      MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
-      MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
-      MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
-      ElementwiseOpConversion<mhlo::CompareOp>,
-      ElementwiseOpConversion<mhlo::SelectOp>>(context);
-  // clang-format on
-#undef MAP_HLO
-#undef MAP_CHLO
-#undef COMMA
-  chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
-      context, patterns);
-  patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
-      chlo::BroadcastSelectOp, mhlo::SelectOp,
-      chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp,
-                                      mhlo::SelectOp>>>(context);
-  chlo::PopulateForBroadcastingBinaryOp<
-      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
-}
-
-std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
-  return std::make_unique<TransformUnrankedHloPass>();
-}
-
-}  // namespace mhlo
-}  // namespace mlir
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
deleted file mode 100644
index 52d6695..0000000
--- a/tests/hlo-transform-unranked.mlir
+++ /dev/null
@@ -1,410 +0,0 @@
-// RUN: mlir-hlo-opt --mhlo-transform-unranked-hlo --cse --split-input-file %s | FileCheck %s
-
-// Check the validity of expected IR.
-// CHECK-LABEL: @sqr_transform_result
-func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
-
-  // Flatten operand shape.
-  %shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
-  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
-  %flat_shape = tensor.from_elements %num_elements : tensor<1xindex>
-  %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
-      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-
-  // Apply operation.
-  %flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
-
-  // Restore original shape.
-  %b = "mhlo.dynamic_reshape"(%flat_b, %shape)
-      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-
-  return %b : tensor<*xf32>
-}
-
-// -----
-
-// Check transformation of unranked code.
-// CHECK-LABEL: @sqrt
-// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
-func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
-  // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
-  // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
-  // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-  // CHECK-NEXT: return %[[B]] : tensor<*xf32>
-  %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
-  return %b : tensor<*xf32>
-}
-
-// -----
-
-// Not transformed when ranked.
-// CHECK-LABEL: @sqrt_ranked
-// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
-func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
-  // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
-  // CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
-  %b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
-  return %b : tensor<3x?xf32>
-}
-
-// -----
-
-// Not transformed when statically shaped.
-// CHECK-LABEL: @sqrt_static
-// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
-func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  // CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
-  %b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
-  return %b : tensor<2x3xf32>
-}
-
-// -----
-
-// Transformed if there is a mix of unranked/static shapes.
-// CHECK-LABEL: @select_mixed
-// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ON_TRUE:.*]]: tensor<*xf32>, %[[ON_FALSE:.*]]: tensor<2xf32>)
-func @select_mixed(%pred: tensor<*xi1>, %on_true: tensor<*xf32>, %on_false: tensor<2xf32>) -> tensor<*xf32> {
-  // CHECK: %[[SHAPE_PRED:.*]] = shape.shape_of %[[PRED]]
-  // CHECK: %[[SHAPE_ON_TRUE:.*]] = shape.shape_of %[[ON_TRUE]]
-  // CHECK: %[[SHAPE_ON_FALSE:.*]] = shape.shape_of %[[ON_FALSE]]
-  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_PRED]], %[[SHAPE_ON_TRUE]], %[[SHAPE_ON_FALSE]]
-  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
-  // CHECK: %[[FLAT_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[FLAT_SHAPE]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
-  // CHECK: %[[FLAT_ON_TRUE:.*]] = "mhlo.dynamic_reshape"(%[[ON_TRUE]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK: %[[FLAT_ON_FALSE:.*]] = "mhlo.dynamic_reshape"(%[[ON_FALSE]], %[[FLAT_SHAPE]]) : (tensor<2xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK: %[[FLAT_RESULT:.*]] = "mhlo.select"(%[[FLAT_PRED]], %[[FLAT_ON_TRUE]], %[[FLAT_ON_FALSE]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-  // CHECK: return %[[RESULT]] : tensor<*xf32>
-  %b = "mhlo.select"(%pred, %on_true, %on_false) : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
-  return %b : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @add_unranked
-// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
-func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
-  // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
-  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
-  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
-  // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
-  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-  // CHECK: return %[[RESULT]] : tensor<*xf32>
-  %result = mhlo.add %a, %b : tensor<*xf32>
-  return %result : tensor<*xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @tan
-// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) -> tensor<*xf32>
-func @tan(%a : tensor<*xf32>) -> tensor<*xf32> {
-  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
-  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
-  // CHECK: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
-  // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-  // CHECK: %[[FLAT_B:.*]] = chlo.tan %[[FLAT_A]] : tensor<?xf32> -> tensor<?xf32>
-  // CHECK: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-  // CHECK: return %[[B]] : tensor<*xf32>
-  %result = chlo.tan %a : tensor<*xf32> -> tensor<*xf32>
-  return %result : tensor<*xf32>
-}
-
-// -----
-
-func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
-  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
-      -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-
-// CHECK-LABEL: func @addScalarUnranked(
-// CHECK-SAME:  %[[ARG_0:.*]]: tensor<f32>,
-// CHECK-SAME:  %[[ARG_1:.*]]: tensor<*xf32>
-// CHECK-SAME:  ) -> tensor<*xf32> {
-// First handle the dynamic reshaping of the unranked operand
-// to a 1D tensor.
-// CHECK-NEXT:  %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
-// CHECK-NEXT:  %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index
-// CHECK-NEXT:  %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
-// CHECK-NEXT:  %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT:  %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[ARG_0]], %[[RESHAPED]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
-// As part of the unranked logic, the result is reshaped back
-// to an unranked tensor.
-// CHECK-NEXT:  %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT:  return %[[RESHAPED_RESULT]] : tensor<*xf32>
-// CHECK-NEXT: }
-
-// -----
-func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
-  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
-      -> tensor<*xf32>
-  return %0 : tensor<*xf32>
-}
-// CHECK-LABEL: func @addUnrankedScalar(
-// CHECK-SAME:  %[[ARG_0:.*]]: tensor<*xf32>,
-// CHECK-SAME:  %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
-// First handle the dynamic reshaping of the unranked operand
-// to a 1D tensor.
-// CHECK-NEXT: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> -// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor -> index -// CHECK-NEXT: %[[SIZE_TENSOR:.*]] = tensor.from_elements %[[NUM_ELEMENTS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// The assuming region is part of the second stage of lowering -// with ranked broadcasting logic. -// CHECK-NEXT: %[[BROADCASTED_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED]], %[[ARG_1]] : (tensor, tensor) -> tensor -// As part of the unranked logic, the result is reshaped back -// to an unranked tensor. -// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[BROADCASTED_RESULT:.*]], %[[SHAPE_0]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: return %[[RESHAPED_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } - -// ----- -func @addUnrankedUnranked( - %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) - -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// CHECK-LABEL: func @addUnrankedUnranked( -// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %[[C0:.*]] = constant 0 : index -// CHECK-NEXT: %[[C1:.*]] = constant 1 : index -// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index -// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[C0]], %[[C1]] : index -// CHECK-NEXT: %[[COUNTER:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[C0]] : index -// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index -// CHECK-NEXT: %[[COUNTER_PLUS_ONE2:.*]] = addi %[[COUNTER]], %[[C1]] : index -// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE2]], %[[COUNTER]] : index -// Handle scalar case -// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER2]], %[[C1]] : index -// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -// Handle equal shapes case -// CHECK-NEXT: %[[VAL_18:.*]] = 
scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex> -// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor -// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor, tensor -// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index -// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index -// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index -// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index -// Handle rank 1 specialization -// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index -// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C2:.*]] = constant 2 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index -// Handle rank 2 specialization -// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// 
CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C3:.*]] = constant 3 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index -// Handle rank 3 specialization -// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C4:.*]] = constant 4 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index -// Handle rank 4 specialization -// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = 
"mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C5:.*]] = constant 5 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index -// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]] -// Handle rank 5 specialization -// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: return %[[VAL_72:.*]] : tensor<*xf32> -// CHECK-NEXT: } - - -// ----- - -func @selectUnrankedUnrankedUnranked( - %arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) - -> tensor<*xf32> { - %0 = chlo.broadcast_select %arg0, %arg1, %arg2 - : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// CHECK-LABEL: func @selectUnrankedUnrankedUnranked( -// CHECK-SAME: %[[PRED:.*]]: tensor<*xi1>, -// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor -// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor -// CHECK-NEXT: %c0 = constant 0 : index -// CHECK-NEXT: %c1 = constant 1 : index -// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %c1 : index -// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %c0, %c1 : index -// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], 
%[[COUNTER_PLUS_ONE]], %c0 : index -// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %c1 : index -// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %c1 : index -// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index -// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %c1 : index -// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %c1 : index -// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index -// CHECK-NEXT: %c2 = constant 2 : index -// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %c2 : index -// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor, tensor) -> tensor -// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor -// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor, tensor -// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1 -// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor -// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex> -// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor, tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = 
"mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor -// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor, tensor -> tensor, tensor, tensor -// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor) -> tensor<*xi1> -// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index -// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index -// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index -// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index -// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor -> index -// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index -// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index -// Handle rank 1 specialization -// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex> -// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> -// CHECK-NEXT: } - -// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor -// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor -// CHECK: chlo.broadcast_select {{.*}} : 
(tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
-// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>
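
For reference, the replacement pipeline: unranked element-wise computations that previously went through `-mhlo-transform-unranked-hlo` are now handled by the two passes named in the commit message. Below is a minimal sketch of an equivalent invocation in the style of the deleted test file. The pass flags are taken from the commit message; the input function is adapted from the deleted `@add_unranked` test above; the intermediate `chlo.rank_specialization_cluster` op named in the comment is an assumption about the form the first pass produces, not something stated in this patch.

// RUN: mlir-hlo-opt --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf --split-input-file %s

// The first pass is expected to wrap the element-wise computation in a
// `chlo.rank_specialization_cluster` region; the second pass then lowers that
// cluster to rank-specialized SCF control flow, analogous to the nested
// `scf.if` ladders checked in the deleted tests above.
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
  %result = chlo.broadcast_add %a, %b : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}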