Move code from helper struct to its only user.

We no longer need the separate helper struct: it is now used in exactly one
place, so its methods can live directly on that class.

PiperOrigin-RevId: 366012639
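The shape of the change, as a minimal standalone sketch (the names HelperFor, UserBefore, UserAfter, and Compute are hypothetical, not from this commit): static methods on a single-use helper struct become const member functions of the one pattern class that calls them, and the call site drops the qualified `Helper<...>::` spelling.

// Before: a templated helper struct whose static methods have exactly one
// caller (hypothetical names, pattern sketch only).
template <typename OpTy>
struct HelperFor {
  static int Compute(int x) { return x + 1; }
};

template <typename OpTy>
struct UserBefore {
  int Run(int x) const { return HelperFor<OpTy>::Compute(x); }
};

// After: the helper body lives directly on its only user as a const member
// function; the extra struct and the qualified call go away.
template <typename OpTy>
struct UserAfter {
  int Run(int x) const { return Compute(x); }
  int Compute(int x) const { return x + 1; }
};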
This commit is contained in:
parent 4033a56750
commit c8157ba4df
@@ -206,180 +206,6 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
   }
 };
 
-template <typename ChloOpTy>
-struct ConvertUnrankedDynamicBroadcastOpHelper {
-  // Returns the dynamic result of checking the given value is effectively a
-  // scalar shape (i.e. the number of elements is 1).
-  static Value GreaterRankIsN(OpBuilder &builder, Location loc,
-                              Value actual_rank, int targeted_rank) {
-    return builder.create<CmpIOp>(
-        loc, CmpIPredicate::eq, actual_rank,
-        builder.create<ConstantIndexOp>(loc, targeted_rank));
-  }
-
-  static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
-      OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
-    // Create the if block to place the current specialized logic in.
-    Value greater_rank_is_n =
-        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
-    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
-                                     greater_rank_is_n, true);
-  }
-
-  static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
-                                          Value shape, int targeted_rank) {
-    auto loc = op.getLoc();
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
-    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_value = builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
-    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
-                                          extended_value);
-  }
-
-  // Create the if statement and code for a broadcasting op with a result of a
-  // given rank.
-  static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
-                                                  ChloOpTy op,
-                                                  ValueRange operands,
-                                                  ValueRange operand_shapes,
-                                                  int targeted_rank) {
-    auto loc = op.getLoc();
-    SmallVector<Value, 2> reshaped_operands;
-
-    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
-        targeted_rank, RankedTensorType::kDynamicSize);
-
-    for (auto it : llvm::zip(operands, operand_shapes)) {
-      Value operand, shape;
-      std::tie(operand, shape) = it;
-      // Handle shape broadcasting and inference.
-      Value extended_operand_casted =
-          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
-
-      // 1. Reshape operands to the given rank (with the same number of
-      // elements)
-      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
-      // ops
-      //    can be broadcasted and do the actual broadcasting)
-      // 3. Type erase the output back to unranked
-      auto reshaped_type = RankedTensorType::get(
-          dynamic_dimensions,
-          operand.getType().template dyn_cast<TensorType>().getElementType());
-      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
-          loc, reshaped_type, operand, extended_operand_casted);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-    auto result_element_type = op.getResult()
-                                   .getType()
-                                   .template dyn_cast<TensorType>()
-                                   .getElementType();
-    auto result_type =
-        RankedTensorType::get(dynamic_dimensions, result_element_type);
-    Value result = if_builder.create<ChloOpTy>(
-        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
-    Value reshaped_result = if_builder.create<tensor::CastOp>(
-        loc, UnrankedTensorType::get(result_element_type), result);
-    if_builder.create<scf::YieldOp>(loc, reshaped_result);
-  }
-
-  // Iterates over the desired ranks to be specialized and generates the code
-  // snippet for each case.
-  static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
-                                    ValueRange operands) {
-    auto loc = op.getLoc();
-
-    // Get the minimum broadcast shapes of the operands.
-    SmallVector<Value> shapes;
-    shapes.reserve(operands.size());
-    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
-                                                    rewriter.getIndexType());
-    for (Value operand : operands) {
-      Value shape =
-          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
-      shapes.push_back(shape);
-    }
-    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
-        loc, extent_tensor_type, shapes, nullptr);
-    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
-    auto reduced_shapes =
-        rewriter
-            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
-            .results();
-    SmallVector<Value> reshaped_operands;
-    reshaped_operands.reserve(operands.size());
-    for (auto it : llvm::zip(operands, reduced_shapes)) {
-      Value operand;
-      Value reduced_shape;
-      std::tie(operand, reduced_shape) = it;
-      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, operand.getType(), operand, reduced_shape);
-      reshaped_operands.push_back(reshaped_operand);
-    }
-
-    // Find the largest rank of the operands.
-    Value greater_rank;
-    for (Value shape : reduced_shapes) {
-      Value rank =
-          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
-      if (!greater_rank) {
-        greater_rank = rank;
-      } else {
-        Value greater_rank_compare = rewriter.create<CmpIOp>(
-            loc, CmpIPredicate::sgt, greater_rank, rank);
-        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
-                                                 greater_rank, rank);
-      }
-    }
-
-    // Generate a list of nested if/else statements to handle rank
-    // specializations from 1 to `kMaxRankSpecialization`.
-    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
-        rewriter, op, greater_rank, 1);
-    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                        reduced_shapes, 1);
-
-    // Put each subsequent rank specialization inside the else statement of the
-    // previous one.
-    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 5;
-    for (int i = 2; i < kMaxRankSpecialization; i++) {
-      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
-          else_builder, op, greater_rank, i);
-      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
-                                          reduced_shapes, i);
-      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
-      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
-    }
-    // Fire an assertion if none of the rank specializations applied (one of
-    // the ranks was greater than `kMaxRankSpecialization`).
-    else_builder.create<AssertOp>(
-        loc,
-        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
-                       kMaxRankSpecialization),
-        "Input for dynamic binary op lowering was of a rank greater than " +
-            std::to_string(kMaxRankSpecialization));
-    // Add the rank 5 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
-                                        reduced_shapes, kMaxRankSpecialization);
-
-    // Return the reshaped result of the outermost if statement.
-    auto result = if_op.getResult(0);
-    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
-        loc, result.getType(), result, broadcast_shape);
-    return reshaped_result;
-  }
-};
-
 // Handles lowering of the following pattern to patterns that will be further
 // matched by other patterns until they result in LHLO:
 //   %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
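The deleted HandleBroadcastAndOp is re-added unchanged as a member function in the last hunk below. For orientation, the control flow it emits behaves like the following hand-written C++ analogue; this is a sketch only (the real code builds nested scf.if ops), and ResultForRank is a hypothetical stand-in for the rank-specialized broadcast-and-op.

#include <cassert>

// Hypothetical stand-in for the rank-r specialized broadcast + op lowering.
int ResultForRank(int r) { return r; }

// Analogue of the emitted cascade: ranks 1 through 4 each get their own
// branch, and the innermost else asserts that the remaining rank is exactly
// kMaxRankSpecialization (5) before emitting the last specialization.
int DispatchOnRank(int greater_rank) {
  constexpr int kMaxRankSpecialization = 5;
  if (greater_rank == 1) return ResultForRank(1);
  for (int i = 2; i < kMaxRankSpecialization; ++i) {
    if (greater_rank == i) return ResultForRank(i);
  }
  assert(greater_rank == kMaxRankSpecialization &&
         "Input for dynamic binary op lowering was of a rank greater than 5");
  return ResultForRank(kMaxRankSpecialization);
}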
@@ -498,8 +324,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
     if_neq_shapes_builder.create<scf::YieldOp>(
         loc,
-        ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
-            if_neq_shapes_builder, op, transformed_operands));
+        HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands));
 
     rewriter.replaceOp(op, {if_op.getResult(0)});
     return success();
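One detail worth calling out in HandleBroadcastAndOp (re-added below): the largest operand rank is computed in the emitted IR as a chain of CmpIOp(sgt) plus SelectOp, which is simply a running maximum. A plain C++ reading of that chain, assuming the ranks were ordinary integers (illustrative only, not part of the pass):

#include <vector>

// C++ analogue of the CmpIOp(sgt)/SelectOp chain that folds the per-operand
// ranks into the single largest rank. Assumes `ranks` is non-empty.
int LargestRank(const std::vector<int> &ranks) {
  int greater_rank = ranks.front();
  for (int rank : ranks) {
    bool greater_rank_compare = greater_rank > rank;            // cmpi sgt
    greater_rank = greater_rank_compare ? greater_rank : rank;  // select
  }
  return greater_rank;
}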
@@ -529,6 +354,177 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
                                                   broadcast_shape);
   }
+
+  // Returns the dynamic result of checking whether the given actual rank
+  // equals `targeted_rank`.
+  Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
+                       int targeted_rank) const {
+    return builder.create<CmpIOp>(
+        loc, CmpIPredicate::eq, actual_rank,
+        builder.create<ConstantIndexOp>(loc, targeted_rank));
+  }
+
+  scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
+      OpBuilder &builder, ChloOpTy op, Value actual_rank,
+      int targeted_rank) const {
+    // Create the if block to place the current specialized logic in.
+    Value greater_rank_is_n =
+        GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
+    return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
+                                     greater_rank_is_n, true);
+  }
+
+  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value shape,
+                                   int targeted_rank) const {
+    auto loc = op.getLoc();
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
+  // Create the if statement and code for a broadcasting op with a result of a
+  // given rank.
+  void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
+                                           ValueRange operands,
+                                           ValueRange operand_shapes,
+                                           int targeted_rank) const {
+    auto loc = op.getLoc();
+    SmallVector<Value, 2> reshaped_operands;
+
+    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+        targeted_rank, RankedTensorType::kDynamicSize);
+
+    for (auto it : llvm::zip(operands, operand_shapes)) {
+      Value operand, shape;
+      std::tie(operand, shape) = it;
+      // Handle shape broadcasting and inference.
+      Value extended_operand_casted =
+          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
+
+      // 1. Reshape operands to the given rank (with the same number of
+      // elements)
+      // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
+      // ops
+      //    can be broadcasted and do the actual broadcasting)
+      // 3. Type erase the output back to unranked
+      auto reshaped_type = RankedTensorType::get(
+          dynamic_dimensions,
+          operand.getType().template dyn_cast<TensorType>().getElementType());
+      Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
+          loc, reshaped_type, operand, extended_operand_casted);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+    auto result_element_type = op.getResult()
+                                   .getType()
+                                   .template dyn_cast<TensorType>()
+                                   .getElementType();
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
+    Value result = if_builder.create<ChloOpTy>(
+        loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
+    Value reshaped_result = if_builder.create<tensor::CastOp>(
+        loc, UnrankedTensorType::get(result_element_type), result);
+    if_builder.create<scf::YieldOp>(loc, reshaped_result);
+  }
+
+  // Iterates over the desired ranks to be specialized and generates the code
+  // snippet for each case.
+  Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
+                             ValueRange operands) const {
+    auto loc = op.getLoc();
+
+    // Get the minimum broadcast shapes of the operands.
+    SmallVector<Value> shapes;
+    shapes.reserve(operands.size());
+    auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                                    rewriter.getIndexType());
+    for (Value operand : operands) {
+      Value shape =
+          rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      shapes.push_back(shape);
+    }
+    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
+        loc, extent_tensor_type, shapes, nullptr);
+    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
+    auto reduced_shapes =
+        rewriter
+            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
+            .results();
+    SmallVector<Value> reshaped_operands;
+    reshaped_operands.reserve(operands.size());
+    for (auto it : llvm::zip(operands, reduced_shapes)) {
+      Value operand;
+      Value reduced_shape;
+      std::tie(operand, reduced_shape) = it;
+      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, operand.getType(), operand, reduced_shape);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+
+    // Find the largest rank of the operands.
+    Value greater_rank;
+    for (Value shape : reduced_shapes) {
+      Value rank =
+          rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
+      if (!greater_rank) {
+        greater_rank = rank;
+      } else {
+        Value greater_rank_compare = rewriter.create<CmpIOp>(
+            loc, CmpIPredicate::sgt, greater_rank, rank);
+        greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
+                                                 greater_rank, rank);
+      }
+    }
+
+    // Generate a list of nested if/else statements to handle rank
+    // specializations from 1 to `kMaxRankSpecialization`.
+    scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
+        rewriter, op, greater_rank, 1);
+    OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
+    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                        reduced_shapes, 1);
+
+    // Put each subsequent rank specialization inside the else statement of the
+    // previous one.
+    OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
+    constexpr int kMaxRankSpecialization = 5;
+    for (int i = 2; i < kMaxRankSpecialization; i++) {
+      auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
+          else_builder, op, greater_rank, i);
+      if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
+      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                          reduced_shapes, i);
+      else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
+      else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
+    }
+    // Fire an assertion if none of the rank specializations applied (one of
+    // the ranks was greater than `kMaxRankSpecialization`).
+    else_builder.create<AssertOp>(
+        loc,
+        GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
+                       kMaxRankSpecialization),
+        "Input for dynamic binary op lowering was of a rank greater than " +
+            std::to_string(kMaxRankSpecialization));
+    // Add the rank 5 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
+                                        reduced_shapes, kMaxRankSpecialization);
+
+    // Return the reshaped result of the outermost if statement.
+    auto result = if_op.getResult(0);
+    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, result.getType(), result, broadcast_shape);
+    return reshaped_result;
+  }
 };
 
 struct TransformUnrankedHloPass
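A note on createBroadcastToKnownRank above: broadcasting a shape against a constant all-ones shape of the targeted rank pads it with leading 1s up to that rank. A minimal C++ analogue of that effect, assuming standard right-aligned broadcast semantics (the function below is illustrative, not part of the pass):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Broadcasting `shape` against [1, 1, ..., 1] of rank `targeted_rank` keeps
// the original extents, right-aligned, and pads with leading 1s.
std::vector<int64_t> BroadcastToKnownRank(const std::vector<int64_t> &shape,
                                          int targeted_rank) {
  assert(static_cast<int>(shape.size()) <= targeted_rank);
  std::vector<int64_t> result(targeted_rank, 1);
  std::copy(shape.begin(), shape.end(), result.end() - shape.size());
  return result;
}

// For example, BroadcastToKnownRank({5, 3}, 4) yields {1, 1, 5, 3}.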