diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 6e2c9db..e23e210 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -183,7 +183,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
     Value size_tensor =
         rewriter.create<tensor::FromElementsOp>(loc, num_elements);
     Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, RankedTensorType::get({-1}, scalar_element_type),
+        loc,
+        RankedTensorType::get({RankedTensorType::kDynamicSize},
+                              scalar_element_type),
         lhs_is_scalar ? rhs : lhs, size_tensor);
 
     // Create a new ranked Chlo op that will be further lowered by other
@@ -191,7 +193,9 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
     SmallVector<Value, 2> new_operands{lhs_is_scalar ? lhs : reshaped,
                                        rhs_is_scalar ? rhs : reshaped};
     Value computed = rewriter.create<ChloOpTy>(
-        loc, TypeRange{RankedTensorType::get({-1}, result_element_type)},
+        loc,
+        TypeRange{RankedTensorType::get({RankedTensorType::kDynamicSize},
+                                        result_element_type)},
         new_operands, op->getAttrs());
 
     // Reshape the result back into an unranked tensor.
@@ -202,7 +206,7 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
   }
 };
 
-template <typename ChloOpTy, typename HloOpTy>
+template <typename ChloOpTy>
 struct ConvertUnrankedDynamicBroadcastOpHelper {
   // Returns the dynamic result of checking the given value is effectively a
   // scalar shape (i.e. the number of elements is 1).
@@ -378,14 +382,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
 
 // Handles lowering of the following pattern to patterns that will be further
 // matched by other patterns until they result in LHLO:
-//   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
+//   %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
 //
 // The sequence of specializations this handles is:
-//   - Either operand being scalar
-//   - Operands having equal shapes
-//   - The resulting value being any of ranks [2,6]
+//   - At most one operand has a shape that does not consist of exactly one
+//     element.
+//   - All operands having equal shapes
+//   - The resulting minimized shapes being any of ranks [1,5]
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertUnrankedDynamicBroadcastBinaryOp
+struct ConvertUnrankedDynamicBroadcastNaryOp
     : public OpConversionPattern<ChloOpTy> {
   using OpConversionPattern<ChloOpTy>::OpConversionPattern;
@@ -394,76 +399,96 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
   LogicalResult matchAndRewrite(
       ChloOpTy op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     typename ChloOpTy::Adaptor transformed(operands);
-    Value lhs = transformed.lhs();
-    Value rhs = transformed.rhs();
-    auto lhs_type = lhs.getType().dyn_cast<UnrankedTensorType>();
-    auto rhs_type = rhs.getType().dyn_cast<UnrankedTensorType>();
+    ValueRange transformed_operands = transformed.getOperands();
+    auto num_operands = transformed_operands.size();
+    llvm::SmallVector<UnrankedTensorType, 3> operand_types;
+    operand_types.reserve(num_operands);
+    for (int i = 0; i < num_operands; ++i) {
+      auto type =
+          transformed_operands[i].getType().dyn_cast<UnrankedTensorType>();
+      // Only support unranked operands.
+      if (!type) return failure();
+      operand_types.push_back(type);
+    }
     auto result_type =
         op.getResult().getType().template dyn_cast<UnrankedTensorType>();
 
-    // Only support unranked operands. If either operand is ranked, another
-    // pattern will handle the lowering.
-    if (!lhs_type || !rhs_type) return failure();
+    llvm::SmallVector<Value, 3> shapes;
+    shapes.reserve(num_operands);
+    for (int i = 0; i < num_operands; ++i) {
+      shapes.push_back(
+          rewriter.create<shape::ShapeOfOp>(loc, transformed_operands[i]));
+    }
 
-    Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
-    Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
+    // Check whether at most one shape does not have exactly one element.
+    Value counter = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+    for (int i = 0; i < num_operands; ++i) {
+      Value is_scalar_like = IsSingleElementShape(rewriter, op, shapes[i]);
+      Value counter_plus_one = rewriter.create<AddIOp>(loc, counter, one);
+      counter = rewriter.create<SelectOp>(loc, is_scalar_like, counter_plus_one,
+                                          counter);
+    }
+    Value num_operands_minus_one =
+        rewriter.create<ConstantIndexOp>(loc, num_operands - 1);
+    Value at_most_one_non_scalar =
+        rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::uge,
+                                counter, num_operands_minus_one);
 
-    // If lhs has exactly one element
-    auto if_op = rewriter.create<scf::IfOp>(
-        loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
-        true);
-    OpBuilder if_lhs_scalar_builder =
+    auto if_op = rewriter.create<scf::IfOp>(loc, result_type,
+                                            at_most_one_non_scalar, true);
+    OpBuilder if_at_most_one_non_scalar_builder =
         if_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
-        loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
-    Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
-        op->getAttrs());
-    Value extended_if_lhs_scalar_result =
-        extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
-                               shape_of_lhs, shape_of_rhs);
-    if_lhs_scalar_builder.create<scf::YieldOp>(loc,
-                                               extended_if_lhs_scalar_result);
-    // If lhs does not have exactly one element
+    llvm::SmallVector<Value, 3> reshaped_operands;
+    reshaped_operands.reserve(num_operands);
+    for (int i = 0; i < num_operands; ++i) {
+      Value num_elements =
+          if_at_most_one_non_scalar_builder.create<shape::NumElementsOp>(
+              loc, shapes[i]);
+      Value size_tensor =
+          if_at_most_one_non_scalar_builder.create<tensor::FromElementsOp>(
+              loc, num_elements);
+      Value reshaped =
+          if_at_most_one_non_scalar_builder.create<mhlo::DynamicReshapeOp>(
+              loc,
+              RankedTensorType::get({RankedTensorType::kDynamicSize},
+                                    operand_types[i].getElementType()),
+              transformed_operands[i], size_tensor);
+      reshaped_operands.push_back(reshaped);
+    }
+
+    auto rank_one_result_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, result_type.getElementType());
+    Value if_at_most_one_non_scalar_result =
+        if_at_most_one_non_scalar_builder.create<ChloOpTy>(
+            loc, ArrayRef<Type>{rank_one_result_type}, reshaped_operands,
+            op->getAttrs());
+    Value extended_result = extendToBroadcastShape(
+        if_at_most_one_non_scalar_builder, loc, result_type,
+        if_at_most_one_non_scalar_result, shapes);
+    if_at_most_one_non_scalar_builder.create<scf::YieldOp>(loc,
+                                                           extended_result);
+
+    // If there is more than one shape which does not have exactly one element
     //
-    // See if rhs has exactly one element
-    OpBuilder else_lhs_scalar_builder =
-        if_op.getElseBodyBuilder(rewriter.getListener());
-    auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
-        loc, result_type,
-        IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
-    else_lhs_scalar_builder.create<scf::YieldOp>(loc,
-                                                 if_rhs_scalar_op.getResult(0));
-    OpBuilder if_rhs_scalar_builder =
-        if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
-        loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
-    Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
-        op->getAttrs());
-    Value extended_if_rhs_scalar_result =
-        extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
-                               shape_of_lhs, shape_of_rhs);
-    if_rhs_scalar_builder.create<scf::YieldOp>(loc,
-                                               extended_if_rhs_scalar_result);
+    // See if all shapes are equal.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    Value equal_shapes =
+        else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[1]);
+    for (int i = 2; i < num_operands; ++i) {
+      Value are_equal =
+          else_builder.create<shape::ShapeEqOp>(loc, shapes[0], shapes[i]);
+      equal_shapes = else_builder.create<AndOp>(loc, equal_shapes, are_equal);
+    }
 
-    // If NEITHER shape has exactly one element
-    //
-    // See if shapes are equal.
-    OpBuilder else_no_scalars_builder =
-        if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
-    Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
-        loc, shape_of_lhs, shape_of_rhs);
-
-    auto if_eq_shapes_op = else_no_scalars_builder.create<scf::IfOp>(
-        loc, result_type, equal_shapes, true);
-    else_no_scalars_builder.create<scf::YieldOp>(loc,
-                                                 if_eq_shapes_op.getResult(0));
+    auto if_eq_shapes_op =
+        else_builder.create<scf::IfOp>(loc, result_type, equal_shapes, true);
+    else_builder.create<scf::YieldOp>(loc, if_eq_shapes_op.getResult(0));
 
     OpBuilder if_eq_shapes_builder =
         if_eq_shapes_op.getThenBodyBuilder(rewriter.getListener());
-    Value non_broadcast_op =
-        Adaptor::CreateOp(op, result_type, {lhs, rhs}, if_eq_shapes_builder);
+    Value non_broadcast_op = Adaptor::CreateOp(
+        op, result_type, transformed_operands, if_eq_shapes_builder);
     if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
 
     // If shapes do not have exactly one element, nor are equal
@@ -472,9 +497,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     OpBuilder if_neq_shapes_builder =
         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
     if_neq_shapes_builder.create<scf::YieldOp>(
-        loc, ConvertUnrankedDynamicBroadcastOpHelper<
-                 ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
-                                                          op, {lhs, rhs}));
+        loc,
+        ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
+            if_neq_shapes_builder, op, transformed_operands));
 
     rewriter.replaceOp(op, {if_op.getResult(0)});
     return success();
@@ -494,37 +519,18 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
         rewriter.create<ConstantIndexOp>(loc, 1));
   }
 
-  Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
-                               Value shape_of_lhs, Value shape_of_rhs) const {
+  Value extendToBroadcastShape(OpBuilder &builder, Location loc,
+                               Type result_type, Value value,
+                               ValueRange shapes) const {
     auto unknown_rank_extent_tensor_type = RankedTensorType::get(
         {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    Value broadcast_shape =
-        builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
-                                           shape_of_lhs, shape_of_rhs, nullptr);
-    return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
+    Value broadcast_shape = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shapes, nullptr);
+    return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
                                                   broadcast_shape);
   }
 };
 
-// Rank-specialize chlo.broadcast_select ops.
-struct ConvertUnrankedDynamicBroadcastSelectOp
-    : public OpConversionPattern<chlo::BroadcastSelectOp> {
-  using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    // For now only do the bare minimum and specialize for every rank. There is
-    // more potential for optimization here. This also is missing the
-    // specialization for rank 0.
-    rewriter.replaceOp(
-        op, {ConvertUnrankedDynamicBroadcastOpHelper<
-            chlo::BroadcastSelectOp,
-            mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)});
-    return success();
-  }
-};
-
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -588,11 +594,14 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
 #undef MAP_HLO
 #undef MAP_CHLO
 #undef COMMA
-  chlo::PopulateForBroadcastingBinaryOp<
-      ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
+  chlo::PopulateForBroadcastingBinaryOp<ConvertUnrankedDynamicBroadcastNaryOp>(
+      context, patterns);
+  patterns->insert<ConvertUnrankedDynamicBroadcastNaryOp<
+      chlo::BroadcastSelectOp, mhlo::SelectOp,
+      chlo::HloNaryElementwiseAdaptor<chlo::BroadcastSelectOp,
+                                      mhlo::SelectOp>>>(context);
   chlo::PopulateForBroadcastingBinaryOp<
      ConvertUnrankedScalarDynamicBroadcastBinaryOp>(context, patterns);
-  patterns->insert<ConvertUnrankedDynamicBroadcastSelectOp>(context);
 }
 
 std::unique_ptr<FunctionPass> createTransformUnrankedHloPass() {
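The generalized pattern above no longer branches once per operand; it emits a counter that is conditionally incremented (via `select`) for every operand whose shape has exactly one element, then takes the scalar path when `counter >= num_operands - 1`. A minimal plain-C++ sketch of the logic that IR computes at runtime (the helper name here is hypothetical, not part of the patch):

#include <cstdint>
#include <vector>

// Plain-C++ analogue of the emitted IR: count scalar-like operands with a
// conditional increment, then test whether at most one operand is
// non-scalar, i.e. counter >= num_operands - 1.
bool AtMostOneNonScalar(const std::vector<int64_t> &num_elements) {
  int64_t counter = 0;
  for (int64_t n : num_elements) {
    bool is_scalar_like = (n == 1);
    // Mirrors: counter = select(is_scalar_like, counter + 1, counter).
    counter = is_scalar_like ? counter + 1 : counter;
  }
  return counter >= static_cast<int64_t>(num_elements.size()) - 1;
}

This formulation scales to any operand count without nesting one `scf.if` per operand, which is what allowed the removal of the dedicated select pattern.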
"mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32> +// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor +// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index -// Handle scalar RHS case -// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor -// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32> +// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor +// Handle equal shapes case +// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor +// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor -> index +// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex> +// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor +// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -// Handle equal shapes case -// CHECK-NEXT: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor -> index -// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex> -// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = 
"mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = mhlo.add %[[FLATTENED_LHS]], %[[FLATTENED_RHS]] : tensor -// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor, tensor) -> tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32> +// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor +// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor, tensor +// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index +// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index +// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index +// Handle rank 1 specialization +// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index +// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] +// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor -// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor -> tensor, tensor -// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor -> index -// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index 
-// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// Handle rank 1 specialization
-// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
-// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
+// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
+// Handle rank 2 specialization
+// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[C2:.*]] = constant 2 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index
-// Handle rank 2 specialization
-// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32>
+// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
+// Handle rank 3 specialization
+// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[C3:.*]] = constant 3 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index
-// Handle rank 3 specialization
-// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32>
+// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
+// Handle rank 4 specialization
+// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
 // CHECK-NEXT: } else {
-// CHECK-NEXT: %[[C4:.*]] = constant 4 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index
-// Handle rank 4 specialization
-// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32>
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
-// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
-// Handle rank 5 specialization
-// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
-// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
+// CHECK-NEXT: %[[C5:.*]] = constant 5 : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
+// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]]
+// Handle rank 5 specialization
+// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
-// CHECK-NEXT: scf.yield %[[VAL_70:.*]] : tensor<*xf32>
+// CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: scf.yield %[[VAL_71:.*]] : tensor<*xf32>
 // CHECK-NEXT: }
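The nested `scf.if` chain checked above dispatches on the greatest rank of the minimized shapes and guards the last supported rank with an `assert` instead of emitting a further branch. A plain-C++ analogue of that control flow, under the assumption (matching the checks above) that the maximum specialized rank is 5; the helper names are hypothetical stand-ins, not part of the patch:

#include <cassert>
#include <cstdint>

// Stand-in for the rank-specialized op application; the real pattern emits
// the reshape/broadcast/op/cast sequence seen in the CHECK lines instead.
template <int64_t Rank>
int64_t ApplyWithRank() { return Rank; }

int64_t ApplyRankSpecialized(int64_t greatest_rank) {
  if (greatest_rank == 1) return ApplyWithRank<1>();
  if (greatest_rank == 2) return ApplyWithRank<2>();
  if (greatest_rank == 3) return ApplyWithRank<3>();
  if (greatest_rank == 4) return ApplyWithRank<4>();
  // The final case is guarded by an assert rather than another branch,
  // mirroring the `assert %[[GREATEST_RANK_IS_5]]` check above.
  assert(greatest_rank == 5 && "greatest minimized rank must be <= 5");
  return ApplyWithRank<5>();
}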
@@ -315,36 +307,77 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK-NEXT: %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
 // CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
 // CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
-// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
-// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
-// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
-// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
-// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
-// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
-// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
-// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
+// CHECK-NEXT: %c0 = constant 0 : index
 // CHECK-NEXT: %c1 = constant 1 : index
-// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
-// Handle rank 1 specialization
-// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
-// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
-// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
-// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
-// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
-// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
-// CHECK-NEXT: }
+// CHECK-NEXT: %[[NUM_ELEMENTS_PRED:.*]] = shape.num_elements %[[PRED_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[PRED_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_PRED]], %c1 : index
+// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %c0, %c1 : index
+// CHECK-NEXT: %[[COUNTER:.*]] = select %[[PRED_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %c0 : index
+// CHECK-NEXT: %[[NUM_ELEMENTS_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_LHS]], %c1 : index
+// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER]], %c1 : index
+// CHECK-NEXT: %[[COUNTER2:.*]] = select %[[LHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER]] : index
+// CHECK-NEXT: %[[NUM_ELEMENTS_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_ELEMENTS_RHS]], %c1 : index
+// CHECK-NEXT: %[[COUNTER_PLUS_ONE:.*]] = addi %[[COUNTER2]], %c1 : index
+// CHECK-NEXT: %[[COUNTER3:.*]] = select %[[RHS_IS_SCALAR]], %[[COUNTER_PLUS_ONE]], %[[COUNTER2]] : index
+// CHECK-NEXT: %c2 = constant 2 : index
+// CHECK-NEXT: %[[IS_SCALAR_CASE:.*]] = cmpi uge, %[[COUNTER3]], %c2 : index
+// CHECK-NEXT: %[[IF_IS_SCALAR_CASE:.*]] = scf.if %[[IS_SCALAR_CASE]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[NUM_TENS_PRED:.*]] = tensor.from_elements %[[NUM_ELEMENTS_PRED]] : tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[NUM_TENS_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_LHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_ELEMENTS_RHS]] : tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[SCALAR_RESULT:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+// CHECK-NEXT: %[[SHAPE_BROADCAST:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[RESHAPED_EXTENDED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[SCALAR_RESULT]], %[[SHAPE_BROADCAST]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[FIRST_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[LHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT: %[[SECOND_SHAPES_EQUAL:.*]] = shape.shape_eq %[[PRED_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT: %[[ALL_SHAPES_EQUAL:.*]] = and %[[FIRST_SHAPES_EQUAL]], %[[SECOND_SHAPES_EQUAL]] : i1
+// CHECK-NEXT: %[[IF_EQUAL_CASE:.*]] = scf.if %[[ALL_SHAPES_EQUAL]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[ANY_SHAPE:.*]] = shape.any %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[ANY_NUM:.*]] = shape.num_elements %[[ANY_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT: %[[ANY_TENSOR:.*]] = tensor.from_elements %[[ANY_NUM]] : tensor<1xindex>
+// CHECK-NEXT: %[[FLATTENED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[ANY_TENSOR]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+// CHECK-NEXT: %[[FLATTENED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[FLATTENED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[ANY_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = "mhlo.select"(%[[FLATTENED_PRED]], %[[FLATTENED_LHS]], %[[FLATTENED_RHS]]) : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT: %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
+// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
+// CHECK-NEXT: %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
+// CHECK-NEXT: %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
+// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
+// CHECK-NEXT: %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
+// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
+// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %c1 : index
+// Handle rank 1 specialization
+// CHECK-NEXT: scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
+// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
+// CHECK-NEXT: %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+// CHECK-NEXT: %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT: %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
+// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32>
+// CHECK-NEXT: }
 // CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
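Both tests rely on `chlo.minimum_broadcast_shapes` to justify specializing only ranks [1,5]: consecutive dimensions in which the same operands broadcast can be collapsed into one, so the minimized ranks stay small regardless of the original ranks. A hedged C++ sketch of that idea for two static shapes; this is a simplified illustration and not the op's implementation, which also handles dynamic dimensions, more operands, and fully drops shared size-1 dimensions:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Merge consecutive dimensions that have the same broadcasting pattern,
// i.e. where the same operand (or neither) has extent 1. The merged shapes
// broadcast to a reshape of the original broadcast result, which is why
// rank specialization only needs to handle the small merged ranks.
std::pair<Shape, Shape> MinimumBroadcastShapes(Shape lhs, Shape rhs) {
  // Right-align the shapes by padding the shorter one with leading 1s.
  size_t rank = std::max(lhs.size(), rhs.size());
  lhs.insert(lhs.begin(), rank - lhs.size(), 1);
  rhs.insert(rhs.begin(), rank - rhs.size(), 1);

  Shape min_lhs, min_rhs;
  for (size_t i = 0; i < rank; ++i) {
    bool same_pattern = !min_lhs.empty() &&
                        (min_lhs.back() == 1) == (lhs[i] == 1) &&
                        (min_rhs.back() == 1) == (rhs[i] == 1);
    if (same_pattern) {
      // Same broadcasting behavior as the previous dimension: combine.
      min_lhs.back() *= lhs[i];
      min_rhs.back() *= rhs[i];
    } else {
      min_lhs.push_back(lhs[i]);
      min_rhs.push_back(rhs[i]);
    }
  }
  return {min_lhs, min_rhs};
}

For example, lhs = [2, 3, 4] and rhs = [2, 3, 1] minimize to [6, 4] and [6, 1]: the first two dimensions broadcast identically and collapse into one, reducing the rank that the specialized code paths must cover.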