diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index e23e210..503faad 100644
--- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -206,180 +206,6 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
   }
 };
 
-template <typename ChloOpTy>
-struct ConvertUnrankedDynamicBroadcastOpHelper {
-  // Returns the dynamic result of checking whether the given rank equals the
-  // targeted rank.
-  static Value GreaterRankIsN(OpBuilder &builder, Location loc,
-                              Value actual_rank, int targeted_rank) {
-    return builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-  }
-
-  static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
-      OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n =
-        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
-    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
-                                     greater_rank_is_n, true);
-  }
-
-  static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
-                                          Value shape, int targeted_rank) {
-    auto loc = op.getLoc();
-    SmallVector<int64_t> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_value = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val,
-        nullptr);
-    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
-                                          extended_value);
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
-                                                  ChloOpTy op,
-                                                  ValueRange operands,
-                                                  ValueRange operand_shapes,
-                                                  int targeted_rank) {
-    auto loc = op.getLoc();
-    SmallVector<Value> reshaped_operands;
-
-    auto dynamic_dimensions = llvm::SmallVector<int64_t>(
-        targeted_rank, RankedTensorType::kDynamicSize);
-
-    for (auto it : llvm::zip(operands, operand_shapes)) {
-      Value operand, shape;
-      std::tie(operand, shape) = it;
-      // Handle shape broadcasting and inference.
-      Value extended_operand_casted =
-          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
-
-      // 1. Reshape operands to the given rank (with the same number of
-      //    elements)
-      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
-      //    ops can be broadcasted and do the actual broadcasting)
-      // 3. Type erase the output back to unranked
-      auto reshaped_type = RankedTensorType::get(
-          dynamic_dimensions,
-          operand.getType().template dyn_cast<TensorType>().getElementType());
-      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
-          loc, reshaped_type, operand, extended_operand_casted);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-    auto result_element_type = op.getResult()
-                                   .getType()
-                                   .template dyn_cast<TensorType>()
-                                   .getElementType();
-    auto result_type =
-        RankedTensorType::get(dynamic_dimensions, result_element_type);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
-    Value reshaped_result = if_builder.create<tensor::CastOp>(
-        loc, UnrankedTensorType::get(result_element_type), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
-                                    ValueRange operands) {
-    auto loc = op.getLoc();
-
-    // Get the minimum broadcast shapes of the operands.
-    SmallVector<Value> shapes;
-    shapes.reserve(operands.size());
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    for (Value operand : operands) {
-      Value shape =
-          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
-      shapes.push_back(shape);
-    }
-    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
-        loc, extent_tensor_type, shapes, nullptr);
-    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
-    auto reduced_shapes =
-        rewriter
-            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
-            .results();
-    SmallVector<Value> reshaped_operands;
-    reshaped_operands.reserve(operands.size());
-    for (auto it : llvm::zip(operands, reduced_shapes)) {
-      Value operand;
-      Value reduced_shape;
-      std::tie(operand, reduced_shape) = it;
-      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, operand.getType(), operand, reduced_shape);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-
-    // Find the largest rank of the operands.
-    Value greater_rank;
-    for (Value shape : reduced_shapes) {
-      Value rank =
-          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
-      if (!greater_rank) {
-        greater_rank = rank;
-      } else {
-        Value greater_rank_compare = rewriter.create<CmpIOp>(
-            loc, CmpIPredicate::sgt, greater_rank, rank);
-        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
-                                                 greater_rank, rank);
-      }
-    }
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 1 to `kMaxRankSpecialization`.
-    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
-        rewriter, op, greater_rank, 1);
-    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                        reduced_shapes, 1);
-
-    // Put each subsequent rank specialization inside the else statement of
-    // the previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 5;
-    for (int i = 2; i < kMaxRankSpecialization; i++) {
-      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
-          else_builder, op, greater_rank, i);
-      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                          reduced_shapes, i);
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
-    }
-    // Fire an assertion if none of the rank specializations applied (one of
-    // the ranks was greater than `kMaxRankSpecialization`).
-    else_builder.create<AssertOp>(
-        loc,
-        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
-                       kMaxRankSpecialization),
-        "Input for dynamic binary op lowering was of a rank greater than " +
-            std::to_string(kMaxRankSpecialization));
-    // Add the rank 5 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
-                                        reduced_shapes,
-                                        kMaxRankSpecialization);
-
-    // Return the reshaped result of the outermost if statement.
-    auto result = if_op.getResult(0);
-    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, result.getType(), result, broadcast_shape);
-    return reshaped_result;
-  }
-};
-
 // Handles lowering of the following pattern to patterns that will be further
 // matched by other patterns until they result in LHLO:
 //   %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
@@ -498,8 +324,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
     if_neq_shapes_builder.create<scf::YieldOp>(
         loc,
-        ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
-            if_neq_shapes_builder, op, transformed_operands));
+        HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands));
 
     rewriter.replaceOp(op, {if_op.getResult(0)});
     return success();
@@ -529,6 +354,177 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
                                                   broadcast_shape);
   }
+
+  // Returns the dynamic result of checking whether the given rank equals the
+  // targeted rank.
+  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
+                       int targeted_rank) const {
+    return builder.create<CmpIOp>(
+        loc, CmpIPredicate::eq, actual_rank,
+        builder.create<ConstantIndexOp>(loc, targeted_rank));
+  }
+
+  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
+      OpBuilder &builder, ChloOpTy op, Value actual_rank,
+      int targeted_rank) const {
+    // Create the if block to place the current specialized logic in.
+    Value greater_rank_is_n =
+        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
+    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
+                                     greater_rank_is_n, true);
+  }
+
+  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
+                                   Value shape, int targeted_rank) const {
+    auto loc = op.getLoc();
+    SmallVector<int64_t> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val,
+        nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
+  // Create the if statement and code for a broadcasting op with a result of a
+  // given rank.
+  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
+                                           ValueRange operands,
+                                           ValueRange operand_shapes,
+                                           int targeted_rank) const {
+    auto loc = op.getLoc();
+    SmallVector<Value> reshaped_operands;
+
+    auto dynamic_dimensions = llvm::SmallVector<int64_t>(
+        targeted_rank, RankedTensorType::kDynamicSize);
+
+    for (auto it : llvm::zip(operands, operand_shapes)) {
+      Value operand, shape;
+      std::tie(operand, shape) = it;
+      // Handle shape broadcasting and inference.
+      Value extended_operand_casted =
+          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
+
+      // 1. Reshape operands to the given rank (with the same number of
+      //    elements)
+      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
+      //    ops can be broadcasted and do the actual broadcasting)
+      // 3. Type erase the output back to unranked
+      auto reshaped_type = RankedTensorType::get(
+          dynamic_dimensions,
+          operand.getType().template dyn_cast<TensorType>().getElementType());
+      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
+          loc, reshaped_type, operand, extended_operand_casted);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+    auto result_element_type = op.getResult()
+                                   .getType()
+                                   .template dyn_cast<TensorType>()
+                                   .getElementType();
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
+    Value result = if_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
+    Value reshaped_result = if_builder.create<tensor::CastOp>(
+        loc, UnrankedTensorType::get(result_element_type), result);
+    if_builder.create<scf::YieldOp>(loc, reshaped_result);
+  }
+
+  // Iterates over the desired ranks to be specialized and generates the code
+  // snippet for each case.
+  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
+                             ValueRange operands) const {
+    auto loc = op.getLoc();
+
+    // Get the minimum broadcast shapes of the operands.
+    SmallVector<Value> shapes;
+    shapes.reserve(operands.size());
+    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                                    rewriter.getIndexType());
+    for (Value operand : operands) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      shapes.push_back(shape);
+    }
+    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
+        loc, extent_tensor_type, shapes, nullptr);
+    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
+    auto reduced_shapes =
+        rewriter
+            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
+            .results();
+    SmallVector<Value> reshaped_operands;
+    reshaped_operands.reserve(operands.size());
+    for (auto it : llvm::zip(operands, reduced_shapes)) {
+      Value operand;
+      Value reduced_shape;
+      std::tie(operand, reduced_shape) = it;
+      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, operand.getType(), operand, reduced_shape);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+
+    // Find the largest rank of the operands.
+    Value greater_rank;
+    for (Value shape : reduced_shapes) {
+      Value rank =
+          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
+      if (!greater_rank) {
+        greater_rank = rank;
+      } else {
+        Value greater_rank_compare = rewriter.create<CmpIOp>(
+            loc, CmpIPredicate::sgt, greater_rank, rank);
+        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
+                                                 greater_rank, rank);
+      }
+    }
+
+    // Generate a list of nested if/else statements to handle rank
+    // specializations from 1 to `kMaxRankSpecialization`.
+    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
+        rewriter, op, greater_rank, 1);
+    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
+    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                        reduced_shapes, 1);
+
+    // Put each subsequent rank specialization inside the else statement of
+    // the previous one.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    constexpr int kMaxRankSpecialization = 5;
+    for (int i = 2; i < kMaxRankSpecialization; i++) {
+      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
+          else_builder, op, greater_rank, i);
+      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
+      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                          reduced_shapes, i);
+      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
+      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
+    }
+    // Fire an assertion if none of the rank specializations applied (one of
+    // the ranks was greater than `kMaxRankSpecialization`).
+    else_builder.create<AssertOp>(
+        loc,
+        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
+                       kMaxRankSpecialization),
+        "Input for dynamic binary op lowering was of a rank greater than " +
+            std::to_string(kMaxRankSpecialization));
+    // Add the rank 5 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
+                                        reduced_shapes,
+                                        kMaxRankSpecialization);
+
+    // Return the reshaped result of the outermost if statement.
+    auto result = if_op.getResult(0);
+    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, result.getType(), result, broadcast_shape);
+    return reshaped_result;
+  }
 };
 
 struct TransformUnrankedHloPass
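
Note (illustrative sketch, not part of the patch): for orientation, the IR that
`HandleBroadcastAndOp` emits has roughly the shape below, hand-written here for
two unranked f32 operands of a hypothetical binary chlo op. Value names are
invented, the bodies of the rank cases are elided, and exact op syntax may
differ from what the pass actually produces; ranks 2-5 nest the same way inside
successive else branches.

    %shape0 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
    %shape1 = shape.shape_of %arg1 : tensor<*xf32> -> tensor<?xindex>
    // Overall broadcast shape, used to reshape the final result.
    %bcast = shape.broadcast %shape0, %shape1
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
    // Rank-reducing shape simplification so more inputs fit the
    // specialized low-rank cases.
    %red:2 = chlo.minimum_broadcast_shapes %shape0, %shape1
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
    %lhs = "mhlo.dynamic_reshape"(%arg0, %red#0)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
    %rhs = "mhlo.dynamic_reshape"(%arg1, %red#1)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
    // Dispatch on the larger of the two reduced ranks.
    %r0 = shape.rank %red#0 : tensor<?xindex> -> index
    %r1 = shape.rank %red#1 : tensor<?xindex> -> index
    %gt = cmpi sgt, %r0, %r1 : index
    %rank = select %gt, %r0, %r1 : index
    %c1 = constant 1 : index
    %is_rank1 = cmpi eq, %rank, %c1 : index
    %res = scf.if %is_rank1 -> (tensor<*xf32>) {
      // Rank-1 case: reshape both sides to tensor<?xf32>, apply the ranked
      // chlo op, type-erase the result back to tensor<*xf32>, and yield.
      ...
    } else {
      // Same pattern nested for ranks 2-4, then an assert followed by the
      // rank-5 case in the innermost else block.
      ...
    }
    // Undo the shape reduction on the result.
    %out = "mhlo.dynamic_reshape"(%res, %bcast)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>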