From b42def4612a9144acc33dca873deb021b92b4a11 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Thu, 18 Feb 2021 09:52:03 -0800
Subject: [PATCH] [mlir][hlo] Refactor rank specialization to allow an arbitrary number of inputs

This actually simplifies the code a bit.

PiperOrigin-RevId: 358201038
---
 .../mhlo/transforms/transform_unranked_hlo.cc | 280 +++++++++---------
 tests/hlo-transform-unranked.mlir             |  12 +-
 2 files changed, 152 insertions(+), 140 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 7c47b6f..3f12d51 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -202,6 +202,149 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
   }
 };
 
+template <typename ChloOpTy, typename HloOpTy>
+struct ConvertUnrankedDynamicBroadcastOpHelper {
+  // Returns the dynamic result of checking that the actual rank equals the
+  // targeted rank.
+  static Value GreaterRankIsN(OpBuilder &builder, Location loc,
+                              Value actual_rank, int targeted_rank) {
+    return builder.create<CmpIOp>(
+        loc, CmpIPredicate::eq, actual_rank,
+        builder.create<ConstantIndexOp>(loc, targeted_rank));
+  }
+
+  static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
+      OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
+    // Create the if block to place the current specialized logic in.
+    Value greater_rank_is_n =
+        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
+    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
+                                     greater_rank_is_n, true);
+  }
+
+  static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
+                                          Value value, int targeted_rank) {
+    auto loc = op.getLoc();
+    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val,
+        nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
+  // Create the if statement and code for a broadcasting op with a result of a
+  // given rank.
+  static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
+                                                  ChloOpTy op,
+                                                  ValueRange operands,
+                                                  int targeted_rank) {
+    auto loc = op.getLoc();
+    SmallVector<Value, 2> reshaped_operands;
+
+    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+        targeted_rank, RankedTensorType::kDynamicSize);
+
+    for (Value operand : operands) {
+      // Handle shape broadcasting and inference.
+      Value extended_operand_casted =
+          createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
+
+      // 1. Reshape operands to the given rank (with the same number of
+      //    elements)
+      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
+      //    ops can be broadcasted and do the actual broadcasting)
+      // 3. Type erase the output back to unranked
+      auto reshaped_type = RankedTensorType::get(
+          dynamic_dimensions,
+          operand.getType().template dyn_cast<TensorType>().getElementType());
+      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
+          loc, reshaped_type, operand, extended_operand_casted);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+    auto result_element_type = op.getResult()
+                                   .getType()
+                                   .template dyn_cast<TensorType>()
+                                   .getElementType();
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
+    Value result = if_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs());
+    Value reshaped_result = if_builder.create<tensor::CastOp>(
+        loc, UnrankedTensorType::get(result_element_type), result);
+    if_builder.create<scf::YieldOp>(loc, reshaped_result);
+  }
+
+  // Iterates over the desired ranks to be specialized and generates the code
+  // snippet for each case.
+  static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
+                                    ValueRange operands) {
+    auto loc = op.getLoc();
+
+    // Find the larger rank of the operands.
+    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                                    rewriter.getIndexType());
+    Value greater_rank;
+    for (Value operand : operands) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      Value rank =
+          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
+      if (!greater_rank) {
+        greater_rank = rank;
+      } else {
+        Value greater_rank_compare = rewriter.create<CmpIOp>(
+            loc, CmpIPredicate::sgt, greater_rank, rank);
+        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
+                                                 greater_rank, rank);
+      }
+    }
+
+    // Generate a list of nested if/else statements to handle rank
+    // specializations from 1 to `kMaxRankSpecialization`.
+    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
+        rewriter, op, greater_rank, 1);
+    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
+    createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
+
+    // Put each subsequent rank specialization inside the else statement of
+    // the previous one.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    constexpr int kMaxRankSpecialization = 6;
+    for (int i = 2; i < kMaxRankSpecialization; i++) {
+      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
+          else_builder, op, greater_rank, i);
+      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
+      createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
+      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
+      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
+    }
+    // Fire an assertion if none of the rank specializations applied (one of
+    // the ranks was greater than `kMaxRankSpecialization`).
+    else_builder.create<AssertOp>(
+        loc,
+        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
+                       kMaxRankSpecialization),
+        "Input for dynamic binary op lowering was of a rank greater than " +
+            std::to_string(kMaxRankSpecialization));
+    // Add the rank 6 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, operands,
+                                        kMaxRankSpecialization);
+
+    // Return the result of the outermost if statement.
+    return if_op.getResult(0);
+  }
+};
+
 // Handles lowering of the following pattern to patterns that will be further
 // matched by other patterns until they result in LHLO:
 //   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
@@ -298,7 +441,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     OpBuilder if_neq_shapes_builder =
         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
     if_neq_shapes_builder.create<scf::YieldOp>(
-        loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
+        loc, ConvertUnrankedDynamicBroadcastOpHelper<
+                 ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder,
+                                                          op, {lhs, rhs}));
 
     rewriter.replaceOp(op, {if_op.getResult(0)});
     return success();
@@ -318,23 +463,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                     rewriter.create<ConstantIndexOp>(loc, 1));
   }
 
-  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
-                       int targeted_rank) const {
-    return builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-  }
-
-  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
-      OpBuilder &builder, ChloOpTy op, Value actual_rank,
-      int targeted_rank) const {
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n =
-        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
-    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
-                                     greater_rank_is_n, true);
-  }
-
   Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
                                Value shape_of_lhs, Value shape_of_rhs) const {
     auto unknown_rank_extent_tensor_type = RankedTensorType::get(
@@ -345,122 +473,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
                                                   broadcast_shape);
   }
-
-  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
-                                   int targeted_rank) const {
-    auto loc = op.getLoc();
-    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_value = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
-    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
-                                          extended_value);
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
-                                           Value lhs, Value rhs,
-                                           int targeted_rank) const {
-    auto loc = op.getLoc();
-
-    // Handle shape broadcasting and inference.
-    Value extended_lhs_casted =
-        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
-    Value extended_rhs_casted =
-        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
-    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
-        targeted_rank, RankedTensorType::kDynamicSize);
-    auto reshaped_type = RankedTensorType::get(
-        dynamic_dimensions,
-        lhs.getType().template dyn_cast<TensorType>().getElementType());
-
-    // 1. Reshape operands to the given rank (with the same number of elements)
-    // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
-    //    can be broadcasted and do the actual broadcasting)
-    // 3. Type erase the output back to unranked
-    Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, lhs, extended_lhs_casted);
-    Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, rhs, extended_rhs_casted);
-    auto result_element_type = op.getResult()
-                                   .getType()
-                                   .template dyn_cast<TensorType>()
-                                   .getElementType();
-    auto result_type =
-        RankedTensorType::get(dynamic_dimensions, result_element_type);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type},
-        ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
-    Value reshaped_result = if_builder.create<tensor::CastOp>(
-        loc, UnrankedTensorType::get(result_element_type), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs,
-                             Value rhs) const {
-    auto loc = op.getLoc();
-
-    // Find the larger rank of the 2 operands.
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    Value lhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs);
-    Value rhs_shape =
-        rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs);
-    Value lhs_rank =
-        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape);
-    Value rhs_rank =
-        rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape);
-    Value greater_rank_lhs =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank);
-    Value greater_rank =
-        rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank);
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 1 to `kMaxRankSpecialization`.
-    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
-        rewriter, op, greater_rank, 1);
-    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1);
-
-    // Put each subsequent rank specialization inside the else statement of the
-    // previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 6;
-    for (int i = 2; i < kMaxRankSpecialization; i++) {
-      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
-          else_builder, op, greater_rank, i);
-      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i);
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
-    }
-    // Fire an assertion if none of the rank specializations applied (one of
-    // the ranks was greater than `kMaxRankSpecialization`).
-    else_builder.create<AssertOp>(
-        loc,
-        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
-                       kMaxRankSpecialization),
-        "Input for dynamic binary op lowering was of a rank greater than " +
-            std::to_string(kMaxRankSpecialization));
-    // Add the rank 6 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs,
-                                        kMaxRankSpecialization);
-
-    // Return the result of the outermost if statement.
-    return if_op.getResult(0);
-  }
 };
 
 struct TransformUnrankedHloPass
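
Because HandleBroadcastAndOp is now a static helper taking a ValueRange, lowering patterns for CHLO ops of other arities can reuse the same rank-specialization logic instead of duplicating it. The sketch below is illustrative only and is not part of this patch: it assumes the conversion-pattern scaffolding already present in transform_unranked_hlo.cc and uses chlo::BroadcastSelectOp / mhlo::SelectOp merely as a plausible three-operand example.

    // Illustrative sketch, not part of this commit. Assumes the usual
    // OpConversionPattern setup from transform_unranked_hlo.cc;
    // chlo::BroadcastSelectOp / mhlo::SelectOp stand in for any
    // multi-operand CHLO/MHLO op pair.
    struct ConvertUnrankedDynamicBroadcastSelectOp
        : public OpConversionPattern<chlo::BroadcastSelectOp> {
      using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;

      LogicalResult matchAndRewrite(
          chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
          ConversionPatternRewriter &rewriter) const override {
        // The helper reshapes every operand to the specialized rank, emits
        // the ranked CHLO op, and casts the result back to an unranked
        // tensor, so the pattern stays trivial regardless of operand count.
        rewriter.replaceOp(
            op, {ConvertUnrankedDynamicBroadcastOpHelper<
                    chlo::BroadcastSelectOp,
                    mhlo::SelectOp>::HandleBroadcastAndOp(rewriter, op,
                                                          operands)});
        return success();
      }
    };
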
diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir
index a074763..43270f3 100644
--- a/tests/hlo-transform-unranked.mlir
+++ b/tests/hlo-transform-unranked.mlir
@@ -209,9 +209,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT: %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
@@ -224,9 +224,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NEXT: %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
@@ -239,9 +239,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
@@ -254,9 +254,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
@@ -269,9 +269,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
@@ -284,9 +284,9 @@ func @addUnrankedUnranked(
 // CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
 // CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
 // CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK-NEXT: %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>