diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 9325d7d..24b8055 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -289,7 +289,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { // Iterates over the desired ranks to be specialized and generates the code // snippet for each case. static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, - ValueRange operands) { + ValueRange operands, + int min_rank_specialization) { auto loc = op.getLoc(); // Get the minimum broadcast shapes of the operands. @@ -336,18 +337,20 @@ struct ConvertUnrankedDynamicBroadcastOpHelper { } // Generate a list of nested if/else statements to handle rank - // specializations from 1 to `kMaxRankSpecialization`. + // specializations from `min_rank_specialization` to + // `kMaxRankSpecialization`. + constexpr int kMaxRankSpecialization = 5; scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( - rewriter, op, greater_rank, 1); + rewriter, op, greater_rank, min_rank_specialization); OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands, - reduced_shapes, 1); + reduced_shapes, + min_rank_specialization); // Put each subsequent rank specialization inside the else statement of the // previous one. OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - constexpr int kMaxRankSpecialization = 5; - for (int i = 2; i < kMaxRankSpecialization; i++) { + for (int i = min_rank_specialization + 1; i < kMaxRankSpecialization; i++) { auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( else_builder, op, greater_rank, i); if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); @@ -468,13 +471,15 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp // If shapes do not have exactly one element, nor are equal // - // See if values are of a rank that we support. + // See if values are of a rank that we support. We already handle cases + // where one of the operands is a scalar, or both are equal. So there should + // be no case left where both have rank 1. OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener()); if_neq_shapes_builder.create( - loc, ConvertUnrankedDynamicBroadcastOpHelper< - ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder, - op, {lhs, rhs})); + loc, ConvertUnrankedDynamicBroadcastOpHelper:: + HandleBroadcastAndOp(if_neq_shapes_builder, op, {lhs, rhs}, + /*min_rank_specialization=*/2)); rewriter.replaceOp(op, {if_op.getResult(0)}); return success(); @@ -518,9 +523,10 @@ struct ConvertUnrankedDynamicBroadcastSelectOp // more potential for optimization here. This also is missing the // specialization for rank 0. rewriter.replaceOp( - op, {ConvertUnrankedDynamicBroadcastOpHelper< - chlo::BroadcastSelectOp, - mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op, operands)}); + op, {ConvertUnrankedDynamicBroadcastOpHelper:: + HandleBroadcastAndOp(rewriter, op, operands, + /*min_rank_specialization=*/1)}); return success(); } }; diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir index b3abe71..53cec49 100644 --- a/tests/hlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -207,85 +207,69 @@ func @addUnrankedUnranked( // CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor -> index // CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index -// Handle rank 1 specialization -// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index -// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] -// CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor, tensor<1xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor to tensor<1xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_1]] : tensor<*xf32> +// CHECK-NEXT: %[[C2:.*]] = constant 2 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index +// Handle rank 2 specialization +// CHECK-NEXT: %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[C2:.*]] = constant 2 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_2:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C2]] : index -// Handle rank 2 specialization -// CHECK-NEXT: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor to tensor<2xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_2]] : tensor<*xf32> +// CHECK-NEXT: %[[C3:.*]] = constant 3 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index +// Handle rank 3 specialization +// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[C3:.*]] = constant 3 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_3:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C3]] : index -// Handle rank 3 specialization -// CHECK-NEXT: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor to tensor<3xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_3]] : tensor<*xf32> +// CHECK-NEXT: %[[C4:.*]] = constant 4 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index +// Handle rank 4 specialization +// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { +// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[C4:.*]] = constant 4 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_4:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C4]] : index -// Handle rank 4 specialization -// CHECK-NEXT: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor to tensor<4xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_4]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[C5:.*]] = constant 5 : index -// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index -// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]] -// Handle rank 5 specialization -// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> -// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32> +// CHECK-NEXT: %[[C5:.*]] = constant 5 : index +// CHECK-NEXT: %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index +// CHECK-NEXT: assert %[[GREATEST_RANK_IS_5]] +// Handle rank 5 specialization +// CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor to tensor<5xindex> +// CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_5]] : tensor<*xf32> // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32> +// CHECK-NEXT: scf.yield %[[VAL_66:.*]] : tensor<*xf32> // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_68:.*]] : tensor<*xf32> +// CHECK-NEXT: scf.yield %[[VAL_67:.*]] : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK-NEXT: scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>