[mlir][hlo] Refactor rank specialization to allow an arbitrary number of inputs
This actually simplifies the code a bit. PiperOrigin-RevId: 358201038
This commit is contained in:
		
							parent
							
								
									ca4034b56e
								
							
						
					
					
						commit
						b42def4612
					
				|  | @ -202,6 +202,149 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename ChloOpTy, typename HloOpTy> | ||||
| struct ConvertUnrankedDynamicBroadcastOpHelper { | ||||
|   // Returns the dynamic result of checking that `actual_rank` is equal to
 | ||||
|   // `targeted_rank`.
 | ||||
|   static Value GreaterRankIsN(OpBuilder &builder, Location loc, | ||||
|                               Value actual_rank, int targeted_rank) { | ||||
|     return builder.create<CmpIOp>( | ||||
|         loc, CmpIPredicate::eq, actual_rank, | ||||
|         builder.create<ConstantIndexOp>(loc, targeted_rank)); | ||||
|   } | ||||
| 
 | ||||
|   static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp( | ||||
|       OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) { | ||||
|     // Create the if block to place the current specialized logic in.
 | ||||
|     Value greater_rank_is_n = | ||||
|         GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank); | ||||
|     return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(), | ||||
|                                      greater_rank_is_n, true); | ||||
|   } | ||||
| 
 | ||||
|   static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, | ||||
|                                           Value value, int targeted_rank) { | ||||
|     auto loc = op.getLoc(); | ||||
|     Value shape = builder.create<shape::ShapeOfOp>(loc, value); | ||||
|     SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1); | ||||
|     auto unknown_rank_extent_tensor_type = RankedTensorType::get( | ||||
|         {RankedTensorType::kDynamicSize}, builder.getIndexType()); | ||||
|     auto known_rank_extent_tensor_type = | ||||
|         RankedTensorType::get({targeted_rank}, builder.getIndexType()); | ||||
|     Value ranked_shape_val = builder.create<shape::ConstShapeOp>( | ||||
|         loc, known_rank_extent_tensor_type, | ||||
|         mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, | ||||
|                                         ranked_shape)); | ||||
|     Value extended_value = builder.create<shape::BroadcastOp>( | ||||
|         loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr); | ||||
|     return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type, | ||||
|                                           extended_value); | ||||
|   } | ||||
| 
 | ||||
|   // Create the if statement and code for a broadcasting op with a result of a
 | ||||
|   // given rank.
 | ||||
|   static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, | ||||
|                                                   ChloOpTy op, | ||||
|                                                   ValueRange operands, | ||||
|                                                   int targeted_rank) { | ||||
|     auto loc = op.getLoc(); | ||||
|     SmallVector<Value, 2> reshaped_operands; | ||||
| 
 | ||||
|     auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>( | ||||
|         targeted_rank, RankedTensorType::kDynamicSize); | ||||
| 
 | ||||
|     for (Value operand : operands) { | ||||
|       // Handle shape broadcasting and inference.
 | ||||
|       Value extended_operand_casted = | ||||
|           createBroadcastToKnownRank(if_builder, op, operand, targeted_rank); | ||||
| 
 | ||||
|       // 1. Reshape operands to the given rank (with the same number of
 | ||||
|       //    elements)
 | ||||
|       // 2. Compute the ranked-broadcasted ChloOp (which will assert that the
 | ||||
|       //    ops can be broadcasted and do the actual broadcasting)
 | ||||
|       // 3. Type erase the output back to unranked
 | ||||
|       auto reshaped_type = RankedTensorType::get( | ||||
|           dynamic_dimensions, | ||||
|           operand.getType().template dyn_cast<TensorType>().getElementType()); | ||||
|       Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>( | ||||
|           loc, reshaped_type, operand, extended_operand_casted); | ||||
|       reshaped_operands.push_back(reshaped_operand); | ||||
|     } | ||||
|     auto result_element_type = op.getResult() | ||||
|                                    .getType() | ||||
|                                    .template dyn_cast<TensorType>() | ||||
|                                    .getElementType(); | ||||
|     auto result_type = | ||||
|         RankedTensorType::get(dynamic_dimensions, result_element_type); | ||||
|     Value result = if_builder.create<ChloOpTy>( | ||||
|         loc, ArrayRef<Type>{result_type}, reshaped_operands, op.getAttrs()); | ||||
|     Value reshaped_result = if_builder.create<tensor::CastOp>( | ||||
|         loc, UnrankedTensorType::get(result_element_type), result); | ||||
|     if_builder.create<scf::YieldOp>(loc, reshaped_result); | ||||
|   } | ||||
| 
 | ||||
|   // Iterates over the desired ranks to be specialized and generates the code
 | ||||
|   // snippet for each case.
 | ||||
|   static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, | ||||
|                                     ValueRange operands) { | ||||
|     auto loc = op.getLoc(); | ||||
| 
 | ||||
|     // Find the largest rank among the operands.
 | ||||
|     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                                     rewriter.getIndexType()); | ||||
|     Value greater_rank; | ||||
|     for (Value operand : operands) { | ||||
|       Value shape = | ||||
|           rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand); | ||||
|       Value rank = | ||||
|           rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape); | ||||
|       if (!greater_rank) { | ||||
|         greater_rank = rank; | ||||
|       } else { | ||||
|         Value greater_rank_compare = rewriter.create<CmpIOp>( | ||||
|             loc, CmpIPredicate::sgt, greater_rank, rank); | ||||
|         greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare, | ||||
|                                                  greater_rank, rank); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // Generate a list of nested if/else statements to handle rank
 | ||||
|     // specializations from 1 to `kMaxRankSpecialization`.
 | ||||
|     scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( | ||||
|         rewriter, op, greater_rank, 1); | ||||
|     OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); | ||||
|     createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1); | ||||
| 
 | ||||
|     // Put each subsequent rank specialization inside the else statement of the
 | ||||
|     // previous one.
 | ||||
|     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); | ||||
|     constexpr int kMaxRankSpecialization = 6; | ||||
|     for (int i = 2; i < kMaxRankSpecialization; i++) { | ||||
|       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( | ||||
|           else_builder, op, greater_rank, i); | ||||
|       if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); | ||||
|       createRankSpecializedBroadcastAndOp(if_builder, op, operands, i); | ||||
|       else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0)); | ||||
|       else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); | ||||
|     } | ||||
|     // Fire an assertion if none of the rank specializations applied (one of
 | ||||
|     // the ranks was greater than `kMaxRankSpecialization`).
 | ||||
|     else_builder.create<AssertOp>( | ||||
|         loc, | ||||
|         GreaterRankIsN(else_builder, op.getLoc(), greater_rank, | ||||
|                        kMaxRankSpecialization), | ||||
|         "Input for dynamic binary op lowering was of a rank greater than " + | ||||
|             std::to_string(kMaxRankSpecialization)); | ||||
|     // Add the rank 6 specialization to the innermost else block.
 | ||||
|     createRankSpecializedBroadcastAndOp(else_builder, op, operands, | ||||
|                                         kMaxRankSpecialization); | ||||
| 
 | ||||
|     // Return the result of the outermost if statement.
 | ||||
|     return if_op.getResult(0); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Handles lowering of the following pattern to patterns that will be further
 | ||||
| // matched by other patterns until they result in LHLO:
 | ||||
| //   %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy>
 | ||||
|  | @ -298,7 +441,9 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp | |||
|     OpBuilder if_neq_shapes_builder = | ||||
|         if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener()); | ||||
|     if_neq_shapes_builder.create<scf::YieldOp>( | ||||
|         loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs)); | ||||
|         loc, ConvertUnrankedDynamicBroadcastOpHelper< | ||||
|                  ChloOpTy, HloOpTy>::HandleBroadcastAndOp(if_neq_shapes_builder, | ||||
|                                                           op, {lhs, rhs})); | ||||
| 
 | ||||
|     rewriter.replaceOp(op, {if_op.getResult(0)}); | ||||
|     return success(); | ||||
|  | @ -318,23 +463,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp | |||
|                                    rewriter.create<ConstantIndexOp>(loc, 1)); | ||||
|   } | ||||
| 
 | ||||
|   Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank, | ||||
|                        int targeted_rank) const { | ||||
|     return builder.create<CmpIOp>( | ||||
|         loc, CmpIPredicate::eq, actual_rank, | ||||
|         builder.create<ConstantIndexOp>(loc, targeted_rank)); | ||||
|   } | ||||
| 
 | ||||
|   scf::IfOp createIfOpForRankSpecializedBroadcastAndOp( | ||||
|       OpBuilder &builder, ChloOpTy op, Value actual_rank, | ||||
|       int targeted_rank) const { | ||||
|     // Create the if block to place the current specialized logic in.
 | ||||
|     Value greater_rank_is_n = | ||||
|         GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank); | ||||
|     return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(), | ||||
|                                      greater_rank_is_n, true); | ||||
|   } | ||||
| 
 | ||||
|   Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value, | ||||
|                                Value shape_of_lhs, Value shape_of_rhs) const { | ||||
|     auto unknown_rank_extent_tensor_type = RankedTensorType::get( | ||||
|  | @ -345,122 +473,6 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp | |||
|     return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value, | ||||
|                                                   broadcast_shape); | ||||
|   } | ||||
| 
 | ||||
|   Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value, | ||||
|                                    int targeted_rank) const { | ||||
|     auto loc = op.getLoc(); | ||||
|     Value shape = builder.create<shape::ShapeOfOp>(loc, value); | ||||
|     SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1); | ||||
|     auto unknown_rank_extent_tensor_type = RankedTensorType::get( | ||||
|         {RankedTensorType::kDynamicSize}, builder.getIndexType()); | ||||
|     auto known_rank_extent_tensor_type = | ||||
|         RankedTensorType::get({targeted_rank}, builder.getIndexType()); | ||||
|     Value ranked_shape_val = builder.create<shape::ConstShapeOp>( | ||||
|         loc, known_rank_extent_tensor_type, | ||||
|         mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type, | ||||
|                                         ranked_shape)); | ||||
|     Value extended_value = builder.create<shape::BroadcastOp>( | ||||
|         loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr); | ||||
|     return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type, | ||||
|                                           extended_value); | ||||
|   } | ||||
| 
 | ||||
|   // Create the if statement and code for a broadcasting op with a result of a
 | ||||
|   // given rank.
 | ||||
|   void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op, | ||||
|                                            Value lhs, Value rhs, | ||||
|                                            int targeted_rank) const { | ||||
|     auto loc = op.getLoc(); | ||||
| 
 | ||||
|     // Handle shape broadcasting and inference.
 | ||||
|     Value extended_lhs_casted = | ||||
|         createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank); | ||||
|     Value extended_rhs_casted = | ||||
|         createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank); | ||||
|     auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>( | ||||
|         targeted_rank, RankedTensorType::kDynamicSize); | ||||
|     auto reshaped_type = RankedTensorType::get( | ||||
|         dynamic_dimensions, | ||||
|         lhs.getType().template dyn_cast<TensorType>().getElementType()); | ||||
| 
 | ||||
|     // 1. Reshape operands to the given rank (with the same number of elements)
 | ||||
|     // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
 | ||||
|     //    can be broadcasted and do the actual broadcasting)
 | ||||
|     // 3. Type erase the output back to unranked
 | ||||
|     Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>( | ||||
|         loc, reshaped_type, lhs, extended_lhs_casted); | ||||
|     Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>( | ||||
|         loc, reshaped_type, rhs, extended_rhs_casted); | ||||
|     auto result_element_type = op.getResult() | ||||
|                                    .getType() | ||||
|                                    .template dyn_cast<TensorType>() | ||||
|                                    .getElementType(); | ||||
|     auto result_type = | ||||
|         RankedTensorType::get(dynamic_dimensions, result_element_type); | ||||
|     Value result = if_builder.create<ChloOpTy>( | ||||
|         loc, ArrayRef<Type>{result_type}, | ||||
|         ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs()); | ||||
|     Value reshaped_result = if_builder.create<tensor::CastOp>( | ||||
|         loc, UnrankedTensorType::get(result_element_type), result); | ||||
|     if_builder.create<scf::YieldOp>(loc, reshaped_result); | ||||
|   } | ||||
| 
 | ||||
|   // Iterates over the desired ranks to be specialized and generates the code
 | ||||
|   // snippet for each case.
 | ||||
|   Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs, | ||||
|                              Value rhs) const { | ||||
|     auto loc = op.getLoc(); | ||||
| 
 | ||||
|     // Find the larger rank of the 2 operands.
 | ||||
|     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                                     rewriter.getIndexType()); | ||||
|     Value lhs_shape = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, lhs); | ||||
|     Value rhs_shape = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, rhs); | ||||
|     Value lhs_rank = | ||||
|         rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), lhs_shape); | ||||
|     Value rhs_rank = | ||||
|         rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), rhs_shape); | ||||
|     Value greater_rank_lhs = | ||||
|         rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank); | ||||
|     Value greater_rank = | ||||
|         rewriter.create<SelectOp>(loc, greater_rank_lhs, lhs_rank, rhs_rank); | ||||
| 
 | ||||
|     // Generate a list of nested if/else statements to handle rank
 | ||||
|     // specializations from 1 to `kMaxRankSpecialization`.
 | ||||
|     scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp( | ||||
|         rewriter, op, greater_rank, 1); | ||||
|     OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener()); | ||||
|     createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, 1); | ||||
| 
 | ||||
|     // Put each subsequent rank specialization inside the else statement of the
 | ||||
|     // previous one.
 | ||||
|     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); | ||||
|     constexpr int kMaxRankSpecialization = 6; | ||||
|     for (int i = 2; i < kMaxRankSpecialization; i++) { | ||||
|       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( | ||||
|           else_builder, op, greater_rank, i); | ||||
|       if_builder = inner_if.getThenBodyBuilder(rewriter.getListener()); | ||||
|       createRankSpecializedBroadcastAndOp(if_builder, op, lhs, rhs, i); | ||||
|       else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0)); | ||||
|       else_builder = inner_if.getElseBodyBuilder(rewriter.getListener()); | ||||
|     } | ||||
|     // Fire an assertion if none of the rank specializations applied (one of
 | ||||
|     // the ranks was greater than `kMaxRankSpecialization`).
 | ||||
|     else_builder.create<AssertOp>( | ||||
|         loc, | ||||
|         GreaterRankIsN(else_builder, op.getLoc(), greater_rank, | ||||
|                        kMaxRankSpecialization), | ||||
|         "Input for dynamic binary op lowering was of a rank greater than " + | ||||
|             std::to_string(kMaxRankSpecialization)); | ||||
|     // Add the rank 6 specialization to the innermost else block.
 | ||||
|     createRankSpecializedBroadcastAndOp(else_builder, op, lhs, rhs, | ||||
|                                         kMaxRankSpecialization); | ||||
| 
 | ||||
|     // Return the result of the outermost if statement.
 | ||||
|     return if_op.getResult(0); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct TransformUnrankedHloPass | ||||
|  |  | |||
|  | @ -209,9 +209,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                   %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] | ||||
| // CHECK-NEXT:                   %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                   %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex> | ||||
| // CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                   %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                   %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex> | ||||
| // CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                   %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                   %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> | ||||
| // CHECK-NEXT:                   %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32> | ||||
|  | @ -224,9 +224,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                     %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] | ||||
| // CHECK-NEXT:                     %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                     %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                     %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                     %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex> | ||||
| // CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> | ||||
| // CHECK-NEXT:                     %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32> | ||||
|  | @ -239,9 +239,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                       %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] | ||||
| // CHECK-NEXT:                       %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                       %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                       %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex> | ||||
| // CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
| // CHECK-NEXT:                       %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32> | ||||
|  | @ -254,9 +254,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                         %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] | ||||
| // CHECK-NEXT:                         %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                         %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                         %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex> | ||||
| // CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK-NEXT:                         %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32> | ||||
|  | @ -269,9 +269,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                           %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32> | ||||
|  | @ -284,9 +284,9 @@ func @addUnrankedUnranked( | |||
| // CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex> | ||||
| // CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue