Legalize MinimumBroadcastShapes op.

Use it in TransformUnrankedHloPass, which allows reducing the maximum rank for rank-specialized broadcasts from 6 to 5.

PiperOrigin-RevId: 360415743
parent 329b1fd071
commit 0683db3b24
@@ -51,6 +51,7 @@ struct ChloLegalizeToHloPass
     conversionTarget.addLegalDialect<
         MhloDialect, mlir::StandardOpsDialect, mlir::tensor::TensorDialect,
         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();
+    conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
 
     if (broadcast_only_) {
       chlo::PopulateChloBroadcastingPatterns(&getContext(),
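Note on the op kept legal above: chlo.minimum_broadcast_shapes takes the operand shapes and returns shapes of possibly lower rank that produce an equivalent broadcast, collapsing consecutive dimensions that broadcast uniformly. A minimal sketch of the assumed semantics, with illustrative shapes not taken from the tests below:

    // Assumed shapes: lhs = [2, 3, 4, 5], rhs = [3, 4, 5]. Under the
    // collapsing semantics sketched above, these would reduce to
    // lhs' = [2, 60] and rhs' = [60], i.e. from ranks 4 and 3 down to 2 and 1.
    %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
    %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
    %reduced:2 = chlo.minimum_broadcast_shapes %lhs_shape, %rhs_shape
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>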
@@ -223,9 +223,8 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
   }
 
   static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
-                                          Value value, int targeted_rank) {
+                                          Value shape, int targeted_rank) {
     auto loc = op.getLoc();
-    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
     SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
     auto unknown_rank_extent_tensor_type = RankedTensorType::get(
         {RankedTensorType::kDynamicSize}, builder.getIndexType());
@@ -246,6 +245,7 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
   static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
                                                   ChloOpTy op,
                                                   ValueRange operands,
+                                                  ValueRange operand_shapes,
                                                   int targeted_rank) {
     auto loc = op.getLoc();
     SmallVector<Value, 2> reshaped_operands;
@@ -253,10 +253,12 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
     auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
         targeted_rank, RankedTensorType::kDynamicSize);
 
-    for (Value operand : operands) {
+    for (auto it : llvm::zip(operands, operand_shapes)) {
+      Value operand, shape;
+      std::tie(operand, shape) = it;
       // Handle shape broadcasting and inference.
       Value extended_operand_casted =
-          createBroadcastToKnownRank(if_builder, op, operand, targeted_rank);
+          createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
 
       // 1. Reshape operands to the given rank (with the same number of
       // elements)
@@ -290,13 +292,37 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
                                     ValueRange operands) {
     auto loc = op.getLoc();
 
-    // Find the larger rank of the operands.
+    // Get the minimum broadcast shapes of the operands.
+    SmallVector<Value> shapes;
+    shapes.reserve(operands.size());
     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
                                                     rewriter.getIndexType());
-    Value greater_rank;
     for (Value operand : operands) {
       Value shape =
           rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
+      shapes.push_back(shape);
+    }
+    auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
+        loc, extent_tensor_type, shapes, nullptr);
+    SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
+    auto reduced_shapes =
+        rewriter
+            .create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
+            .results();
+    SmallVector<Value> reshaped_operands;
+    reshaped_operands.reserve(operands.size());
+    for (auto it : llvm::zip(operands, reduced_shapes)) {
+      Value operand;
+      Value reduced_shape;
+      std::tie(operand, reduced_shape) = it;
+      auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, operand.getType(), operand, reduced_shape);
+      reshaped_operands.push_back(reshaped_operand);
+    }
+
+    // Find the largest rank of the operands.
+    Value greater_rank;
+    for (Value shape : reduced_shapes) {
       Value rank =
           rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
       if (!greater_rank) {
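Taken together, the hunk above makes the lowering emit roughly the following prologue before the rank-specialized if-cascade; this mirrors the CHECK lines in the updated test further down (value names here are illustrative):

    %result_shape = shape.broadcast %lhs_shape, %rhs_shape
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
    %min_shapes:2 = chlo.minimum_broadcast_shapes %lhs_shape, %rhs_shape
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
    %lhs_min = "mhlo.dynamic_reshape"(%lhs, %min_shapes#0)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
    %rhs_min = "mhlo.dynamic_reshape"(%rhs, %min_shapes#1)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>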
@@ -314,17 +340,19 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
     scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
         rewriter, op, greater_rank, 1);
     OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
-    createRankSpecializedBroadcastAndOp(if_builder, op, operands, 1);
+    createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                        reduced_shapes, 1);
 
     // Put each subsequent rank specialization inside the else statement of the
     // previous one.
     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 6;
+    constexpr int kMaxRankSpecialization = 5;
     for (int i = 2; i < kMaxRankSpecialization; i++) {
       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
           else_builder, op, greater_rank, i);
       if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
-      createRankSpecializedBroadcastAndOp(if_builder, op, operands, i);
+      createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
+                                          reduced_shapes, i);
       else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
       else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
     }
@@ -336,12 +364,15 @@ struct ConvertUnrankedDynamicBroadcastOpHelper {
                        kMaxRankSpecialization),
         "Input for dynamic binary op lowering was of a rank greater than " +
             std::to_string(kMaxRankSpecialization));
-    // Add the rank 6 specialization to the innermost else block.
-    createRankSpecializedBroadcastAndOp(else_builder, op, operands,
-                                        kMaxRankSpecialization);
+    // Add the rank 5 specialization to the innermost else block.
+    createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
+                                        reduced_shapes, kMaxRankSpecialization);
 
-    // Return the result of the outermost if statement.
-    return if_op.getResult(0);
+    // Return the reshaped result of the outermost if statement.
+    auto result = if_op.getResult(0);
+    auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
+        loc, result.getType(), result, broadcast_shape);
+    return reshaped_result;
   }
 };
 
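Why the cap drops to 5: the pass now assumes that, once the operands have been reshaped to their minimum broadcast shapes, a reduced rank above 5 does not occur, so the cascade specializes ranks 1 through 5 and asserts instead of adding a further branch. Because the rank-specialized op then runs on the collapsed shapes, its result is expanded back to the full broadcast shape computed up front, roughly (illustrative names):

    %result = "mhlo.dynamic_reshape"(%rank_specialized_result, %result_shape)
        : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>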
@@ -497,16 +528,17 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect>();
   }
 
   void runOnFunction() override {
     // Setup conversion target.
     MLIRContext &ctx = getContext();
     ConversionTarget target(ctx);
-    target.addLegalDialect<mhlo::MhloDialect, StandardOpsDialect,
-                           shape::ShapeDialect, scf::SCFDialect,
-                           tensor::TensorDialect>();
+    target.addLegalDialect<chlo::HloClientDialect, mhlo::MhloDialect,
+                           StandardOpsDialect, shape::ShapeDialect,
+                           scf::SCFDialect, tensor::TensorDialect>();
     target.addLegalOp<FuncOp>();
 #define ADD_LEGAL_MHLO(op) AddLegalOpOnRankedTensor<mhlo::op>(&target)
 #define ADD_LEGAL_CHLO(op) AddLegalOpOnRankedTensor<chlo::op>(&target)
@@ -199,20 +199,24 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                 %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
 // CHECK-NEXT:                 scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
 // CHECK-NEXT:               } else {
-// CHECK-NEXT:                 %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT:                 %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:                 %[[RESULT_SHAPE:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT:                 %[[MINIMUM_SHAPES:.*]]:2 = chlo.minimum_broadcast_shapes %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT:                 %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:                 %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:                 %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
+// CHECK-NEXT:                 %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
 // CHECK-NEXT:                 %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT:                 %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
 //                             Handle rank 1 specialization
 // CHECK-NEXT:                 %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
-// CHECK-NEXT:                 %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
+// CHECK-NEXT:                 %[[RESULT_RANK_SPECIALIZATION:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                   %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
-// CHECK-NEXT:                   %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT:                   %[[BROADCASTED_LHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:                   %[[CASTED_LHS_1:.*]] = tensor.cast %[[BROADCASTED_LHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT:                   %[[RESHAPED_LHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT:                   %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT:                   %[[BROADCASTED_RHS_1:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:                   %[[CASTED_RHS_1:.*]] = tensor.cast %[[BROADCASTED_RHS_1]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:                   %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT:                   %[[RESHAPED_RHS_1:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_1]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESULT_RANK_1:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_1]], %[[RESHAPED_RHS_1]] : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT:                   %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1]] : tensor<?xf32> to tensor<*xf32>
 // CHECK-NEXT:                   scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -222,12 +226,12 @@ func @addUnrankedUnranked(
 //                               Handle rank 2 specialization
 // CHECK-NEXT:                   %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                     %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
-// CHECK-NEXT:                     %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT:                     %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT:                     %[[CASTED_LHS_2:.*]] = tensor.cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT:                     %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK-NEXT:                     %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK-NEXT:                     %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
 // CHECK-NEXT:                     %[[CASTED_RHS_2:.*]] = tensor.cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
-// CHECK-NEXT:                     %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK-NEXT:                     %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-NEXT:                     %[[RESULT_2:.*]] = tensor.cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                     scf.yield %[[RESULT_2]] : tensor<*xf32>
@@ -237,12 +241,12 @@ func @addUnrankedUnranked(
 //                                 Handle rank 3 specialization
 // CHECK-NEXT:                     %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                       %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK-NEXT:                       %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT:                       %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT:                       %[[CASTED_LHS_3:.*]] = tensor.cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT:                       %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-// CHECK-NEXT:                       %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK-NEXT:                       %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
 // CHECK-NEXT:                       %[[CASTED_RHS_3:.*]] = tensor.cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
-// CHECK-NEXT:                       %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK-NEXT:                       %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK-NEXT:                       %[[RESULT_3:.*]] = tensor.cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                       scf.yield %[[RESULT_3]] : tensor<*xf32>
@@ -252,47 +256,30 @@ func @addUnrankedUnranked(
 //                                   Handle rank 4 specialization
 // CHECK-NEXT:                       %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 // CHECK-NEXT:                         %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK-NEXT:                         %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT:                         %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT:                         %[[CASTED_LHS_4:.*]] = tensor.cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT:                         %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK-NEXT:                         %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK-NEXT:                         %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
 // CHECK-NEXT:                         %[[CASTED_RHS_4:.*]] = tensor.cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
-// CHECK-NEXT:                         %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT:                         %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK-NEXT:                         %[[RESULT_4:.*]] = tensor.cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 // CHECK-NEXT:                         scf.yield %[[RESULT_4]] : tensor<*xf32>
 // CHECK-NEXT:                       } else {
 // CHECK-NEXT:                         %[[C5:.*]] = constant 5 : index
 // CHECK-NEXT:                         %[[GREATEST_RANK_IS_5:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C5]] : index
+// CHECK-NEXT:                         assert %[[GREATEST_RANK_IS_5]]
 //                                     Handle rank 5 specialization
-// CHECK-NEXT:                         %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
-// CHECK-NEXT:                           %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK-NEXT:                           %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT:                           %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
-// CHECK-NEXT:                           %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT:                           scf.yield %[[RESULT_5]] : tensor<*xf32>
-// CHECK-NEXT:                         } else {
-// CHECK-NEXT:                           %[[C6:.*]] = constant 6 : index
-// CHECK-NEXT:                           %[[GREATEST_RANK_IS_6:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C6]] : index
-// CHECK-NEXT:                           assert %[[GREATEST_RANK_IS_6]]
-//                                       Handle rank 6 specialization
-// CHECK-NEXT:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
-// CHECK-NEXT:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_LHS_6:.*]] = tensor.cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
-// CHECK-NEXT:                           %[[CASTED_RHS_6:.*]] = tensor.cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
-// CHECK-NEXT:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK-NEXT:                           %[[RESULT_6:.*]] = tensor.cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
-// CHECK-NEXT:                           scf.yield %[[RESULT_6]] : tensor<*xf32>
-// CHECK-NEXT:                         }
-// CHECK-NEXT:                         scf.yield %[[VAL_65:.*]] : tensor<*xf32>
+// CHECK-NEXT:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
+// CHECK-NEXT:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT:                         %[[CASTED_LHS_5:.*]] = tensor.cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK-NEXT:                         %[[CASTED_RHS_5:.*]] = tensor.cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK-NEXT:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+// CHECK-NEXT:                         %[[RESULT_5:.*]] = tensor.cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
+// CHECK-NEXT:                         scf.yield %[[RESULT_5]] : tensor<*xf32>
 // CHECK-NEXT:                       }
 // CHECK-NEXT:                       scf.yield %[[VAL_66:.*]] : tensor<*xf32>
 // CHECK-NEXT:                     }
@@ -300,7 +287,8 @@ func @addUnrankedUnranked(
 // CHECK-NEXT:                   }
 // CHECK-NEXT:                   scf.yield %[[VAL_68:.*]] : tensor<*xf32>
 // CHECK-NEXT:                 }
-// CHECK-NEXT:                 scf.yield %[[VAL_69:.*]] : tensor<*xf32>
+// CHECK-NEXT:                 %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESULT_RANK_SPECIALIZATION]], %[[RESULT_SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:                 scf.yield %[[RESHAPED_RESULT]] : tensor<*xf32>
 // CHECK-NEXT:               }
 // CHECK-NEXT:               scf.yield %[[VAL_70:.*]] : tensor<*xf32>
 // CHECK-NEXT:             }
@@ -325,13 +313,18 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK-SAME:     %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME:     %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-NEXT:    %[[PRED_SHAPE:.*]] = shape.shape_of %[[PRED]] : tensor<*xi1> -> tensor<?xindex>
-// CHECK-NEXT:    %[[PRED_RANK:.*]] = shape.rank %[[PRED_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT:    %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT:    %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:    %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK-NEXT:    %[[RESULT_SHAPE:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT:    %[[MINIMUM_SHAPES:.*]]:3 = chlo.minimum_broadcast_shapes %[[PRED_SHAPE]], %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+// CHECK-NEXT:    %[[MINIMUM_RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[MINIMUM_SHAPES]]#0) : (tensor<*xi1>, tensor<?xindex>) -> tensor<*xi1>
+// CHECK-NEXT:    %[[MINIMUM_RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[MINIMUM_SHAPES]]#1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:    %[[MINIMUM_RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[MINIMUM_SHAPES]]#2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:    %[[PRED_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#0 : tensor<?xindex> -> index
+// CHECK-NEXT:    %[[LHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#1 : tensor<?xindex> -> index
 // CHECK-NEXT:    %[[GREATER_RANK_CMP:.*]] = cmpi sgt, %[[PRED_RANK]], %[[LHS_RANK]] : index
 // CHECK-NEXT:    %[[GREATER_RANK:.*]] = select %[[GREATER_RANK_CMP]], %[[PRED_RANK]], %[[LHS_RANK]] : index
-// CHECK-NEXT:    %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT:    %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:    %[[RHS_RANK:.*]] = shape.rank %[[MINIMUM_SHAPES]]#2 : tensor<?xindex> -> index
 // CHECK-NEXT:    %[[GREATEST_RANK_CMP:.*]] = cmpi sgt, %[[GREATER_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT:    %[[GREATEST_RANK:.*]] = select %[[GREATEST_RANK_CMP]], %[[GREATER_RANK]], %[[RHS_RANK]] : index
 // CHECK-NEXT:    %c1 = constant 1 : index
@@ -339,15 +332,15 @@ func @selectUnrankedUnrankedUnranked(
 //                Handle rank 1 specialization
 // CHECK-NEXT:    scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
 // CHECK-NEXT:      %[[CONST_SHAPE_1:.*]] = shape.const_shape [1] : tensor<1xindex>
-// CHECK-NEXT:      %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[PRED_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT:      %[[BROADCASTED_PRED:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#0, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:      %[[CASTED_PRED:.*]] = tensor.cast %[[BROADCASTED_PRED]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:      %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
+// CHECK-NEXT:      %[[RESHAPED_PRED:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_PRED]], %[[CASTED_PRED]]) : (tensor<*xi1>, tensor<1xindex>) -> tensor<?xi1>
-// CHECK-NEXT:      %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT:      %[[BROADCASTED_LHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#1, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:      %[[CASTED_LHS:.*]] = tensor.cast %[[BROADCASTED_LHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:      %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT:      %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_LHS]], %[[CASTED_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK-NEXT:      %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
+// CHECK-NEXT:      %[[BROADCASTED_RHS:.*]] = shape.broadcast %[[MINIMUM_SHAPES]]#2, %[[CONST_SHAPE_1]] : tensor<?xindex>, tensor<1xindex> -> tensor<?xindex>
 // CHECK-NEXT:      %[[CASTED_RHS:.*]] = tensor.cast %[[BROADCASTED_RHS]] : tensor<?xindex> to tensor<1xindex>
-// CHECK-NEXT:      %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK-NEXT:      %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[MINIMUM_RESHAPED_RHS]], %[[CASTED_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:      %[[RESULT_RANK_1:.*]] = chlo.broadcast_select %[[RESHAPED_PRED]], %[[RESHAPED_LHS]], %[[RESHAPED_RHS]] : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-NEXT:      %[[RESULT_1:.*]] = tensor.cast %[[RESULT_RANK_1:.*]] : tensor<?xf32> to tensor<*xf32>
 // CHECK-NEXT:      scf.yield %[[RESULT_1]] : tensor<*xf32>
@@ -357,4 +350,3 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>