diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 842423e..a47ce7b 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -164,7 +164,10 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp // the more generic case of both inputs being unranked. if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure(); + auto scalar_element_type = lhs_is_scalar ? lhs_ranked_type.getElementType() + : rhs_ranked_type.getElementType(); auto result_type = op.getResult().getType().template dyn_cast(); + auto result_element_type = result_type.getElementType(); // Reshape the non-scalar value into a dynamically sized, rank-1 tensor Value shape = @@ -173,16 +176,16 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp Value size_tensor = rewriter.create(loc, num_elements); Value reshaped = rewriter.create( - loc, RankedTensorType::get({-1}, result_type.getElementType()), + loc, RankedTensorType::get({-1}, scalar_element_type), lhs_is_scalar ? rhs : lhs, size_tensor); // Create a new ranked Chlo op that will be further lowered by other // patterns into Mhlo. SmallVector new_operands{lhs_is_scalar ? lhs : reshaped, rhs_is_scalar ? rhs : reshaped}; - Value computed = - rewriter.create(loc, SmallVector{reshaped.getType()}, - new_operands, op.getAttrs()); + Value computed = rewriter.create( + loc, TypeRange{RankedTensorType::get({-1}, result_element_type)}, + new_operands, op.getAttrs()); // Reshape the result back into an unranked tensor. rewriter.replaceOpWithNewOp(op, result_type, @@ -287,8 +290,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp } private: - // Returns the dyanamic result of checking the given value is a scalar - // tensor. + // Returns the dynamic result of checking the given value is a scalar tensor. Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const { auto loc = op.getLoc(); @@ -300,30 +302,38 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp rewriter.create(loc, 0)); } - // Create the if statement and code for a broadcasting op with a result of a - // given rank. - scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op, - Value lhs, Value rhs, - Value actual_rank, - int targeted_rank) const { - auto loc = op.getLoc(); - - // Create the if block to place the current specialized logic in. - Value greater_rank_is_n = builder.create( + Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank, + int targeted_rank) const { + return builder.create( loc, CmpIPredicate::eq, actual_rank, builder.create(loc, targeted_rank)); - auto if_op = - builder.create(loc, lhs.getType(), greater_rank_is_n, true); - OpBuilder if_builder = if_op.getThenBodyBuilder(builder.getListener()); + } + + scf::IfOp createIfOpForRankSpecializedBroadcastAndOp( + OpBuilder &builder, ChloOpTy op, Value actual_rank, + int targeted_rank) const { + // Create the if block to place the current specialized logic in. + Value greater_rank_is_n = + GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank); + return builder.create(op.getLoc(), op.getResult().getType(), + greater_rank_is_n, true); + } + + // Create the if statement and code for a broadcasting op with a result of a + // given rank. + void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op, + Value lhs, Value rhs, + int targeted_rank) const { + auto loc = op.getLoc(); // Handle shape broadcasting and inferrence. Value lhs_shape = if_builder.create(loc, lhs); Value rhs_shape = if_builder.create(loc, rhs); SmallVector ranked_shape(targeted_rank, 1); auto unknown_rank_extent_tensor_type = RankedTensorType::get( - {RankedTensorType::kDynamicSize}, builder.getIndexType()); + {RankedTensorType::kDynamicSize}, if_builder.getIndexType()); auto known_rank_extent_tensor_type = - RankedTensorType::get({targeted_rank}, builder.getIndexType()); + RankedTensorType::get({targeted_rank}, if_builder.getIndexType()); auto reshaped_type = RankedTensorType::get( llvm::SmallVector(targeted_rank, RankedTensorType::kDynamicSize), @@ -351,23 +361,26 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp loc, reshaped_type, lhs, extended_lhs_casted); Value reshaped_rhs = if_builder.create( loc, reshaped_type, rhs, extended_rhs_casted); + auto result_element_type = op.getResult() + .getType() + .template dyn_cast() + .getElementType(); + auto result_type = RankedTensorType::get( + llvm::SmallVector(targeted_rank, + RankedTensorType::kDynamicSize), + result_element_type); Value result = if_builder.create( - loc, ArrayRef{reshaped_type}, + loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, reshaped_rhs}, op.getAttrs()); Value reshaped_result = if_builder.create( - loc, UnrankedTensorType::get(reshaped_type.getElementType()), result); + loc, UnrankedTensorType::get(result_element_type), result); if_builder.create(loc, reshaped_result); - - // Return the if_op, so the result can be used and the else block can be - // used for the next rank specialized step. - return if_op; } // Iterates over the desired ranks to be specialized and generates the code // snippet for each case. Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs, Value rhs) const { - constexpr int max_rank_specialization = 7; auto loc = op.getLoc(); // Find the larger rank of the 2 operands. @@ -388,26 +401,34 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp // Generate a list of nested if/else statements to handle rank // specializations from 1-6. - scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs, - rhs, greater_rank, 1); + scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( + rewriter, op, greater_rank, 1); + OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); + createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1); // Put each subsequent rank specialization inside the else statement of the // previous one. OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - for (int i = 2; i < max_rank_specialization; i++) { - auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs, - rhs, greater_rank, i); - + constexpr int kMaxRankSpecialization = 6; + for (int i = 2; i < kMaxRankSpecialization; i++) { + auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( + else_builder, op, greater_rank, i); + if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); + createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i); else_builder.create(loc, inner_if.getResult(0)); else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); } - - // Fire an assertion if none of the rank specializations applied (one of the - // ranks was greater than 6). + // Fire an assertion if none of the rank specializations applied (one of + // the ranks was greater than 6). else_builder.create( - loc, else_builder.create(loc, 0, 1), - "Input for dynamic binary op lowering was of a rank greater than 6"); - else_builder.create(loc, lhs); + loc, + GreaterRankIsN(else_builder, op.getLoc(), greater_rank, + kMaxRankSpecialization), + "Input for dynamic binary op lowering was of a rank greater than " + "6"); + // Add the rank 6 specialization to the innermost else block. + createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs, + kMaxRankSpecialization); // Return the result of the outermost if statement. return if_op.getResult(0); diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir index cce0a94..af83b4a 100644 --- a/tests/hlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -276,24 +276,18 @@ func @addUnrankedUnranked( // CHECK-NEXT: } else { // CHECK-NEXT: %[[C6:.*]] = constant 6 : index // CHECK-NEXT: %[[GREATEST_RANK_IS_6:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index +// CHECK-NEXT: assert %[[GREATEST_RANK_IS_6]] // Handle rank 6 specialization -// CHECK-NEXT: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { -// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] -// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> -// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor -// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> -// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor -// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor -// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> -// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32> -// CHECK-NEXT: } else { -// CHECK-NEXT: %false = constant false -// CHECK-NEXT: assert %false -// CHECK-NEXT: scf.yield %[[LHS]] : tensor<*xf32> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %[[VAL_64:.*]] : tensor<*xf32> +// CHECK-NEXT: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] +// CHECK-NEXT: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK-NEXT: %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor to tensor<6xindex> +// CHECK-NEXT: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor +// CHECK-NEXT: %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor to tensor<6xindex> +// CHECK-NEXT: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK-NEXT: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK-NEXT: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor +// CHECK-NEXT: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> +// CHECK-NEXT: scf.yield %[[RESULT_6]] : tensor<*xf32> // CHECK-NEXT: } // CHECK-NEXT: scf.yield %[[VAL_65:.*]] : tensor<*xf32> // CHECK-NEXT: }