Insert explicit casts to model extra shape knowledge for unranked chlo transform
When transforming unranked binary operations from CHLO to HLO, we insert
`shape.broadcast` operations. Due to context, we know that the result of the
`shape.broadcast` operation has a static shape. Instead of modelling this in
the type of the broadcast operation itself, which is illegal, we now use an
explicit cast.

PiperOrigin-RevId: 331989879
commit 2aa07b0091
parent 1880f87737
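For illustration, the lowering now materializes the extra shape knowledge like
this (a hand-written sketch of the rank-2 case, with placeholder value names,
distilled from the test expectations below):

  // Before: the broadcast itself claimed a static result shape, which
  // shape.broadcast cannot legally guarantee on its own.
  %b = shape.broadcast %lhs_shape, %cst
      : tensor<?xindex>, tensor<2xindex> -> tensor<2xindex>

  // After: broadcast at the unranked extent tensor type, then assert the
  // contextually known rank with an explicit cast.
  %b = shape.broadcast %lhs_shape, %cst
      : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
  %c = tensor_cast %b : tensor<?xindex> to tensor<2xindex>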
@@ -373,30 +373,37 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
     Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
     SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
     auto reshaped_type = RankedTensorType::get(
         llvm::SmallVector<int64_t, 6>(targeted_rank,
                                       RankedTensorType::kDynamicSize),
         lhs.getType().template dyn_cast<TensorType>().getElementType());
     Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
-        loc, extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape));
-    // TODO(tpopp): Return extent tensors when possible to signal that this is a
-    // guaranteed safe broadcast by construction.
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
     Value extended_lhs = if_builder.create<shape::BroadcastOp>(
-        loc, extent_tensor_type, lhs_shape, ranked_shape_val, nullptr);
+        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
+        nullptr);
+    Value extended_lhs_casted = if_builder.create<TensorCastOp>(
+        loc, known_rank_extent_tensor_type, extended_lhs);
     Value extended_rhs = if_builder.create<shape::BroadcastOp>(
-        loc, extent_tensor_type, rhs_shape, ranked_shape_val, nullptr);
+        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
+        nullptr);
+    Value extended_rhs_casted = if_builder.create<TensorCastOp>(
+        loc, known_rank_extent_tensor_type, extended_rhs);
 
     // 1. Reshape operands to the given rank (with the same number of elements)
     // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
     //    can be broadcasted and do the actual broadcasting)
     // 3. Type erase the output back to unranked
     Value reshaped_lhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, lhs, extended_lhs);
+        loc, reshaped_type, lhs, extended_lhs_casted);
     Value reshaped_rhs = if_builder.create<mhlo::DynamicReshapeOp>(
-        loc, reshaped_type, rhs, extended_rhs);
+        loc, reshaped_type, rhs, extended_rhs_casted);
     Value result = if_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{reshaped_type},
         ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
@@ -353,10 +353,12 @@ func @addUnrankedUnranked(
 //                        Handle rank 2 specialization
 // CHECK:                 %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) {
 // CHECK:                   %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1]
-// CHECK:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<2xindex>
-// CHECK:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<2xindex>
-// CHECK:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK:                   %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK:                   %[[CASTED_LHS_2:.*]] = tensor_cast %[[BROADCASTED_LHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK:                   %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
+// CHECK:                   %[[CASTED_RHS_2:.*]] = tensor_cast %[[BROADCASTED_RHS_2]] : tensor<?xindex> to tensor<2xindex>
+// CHECK:                   %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK:                   %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK:                   %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:                   %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor<?x?xf32> to tensor<*xf32>
 // CHECK:                   scf.yield %[[RESULT_2]] : tensor<*xf32>
@@ -366,10 +368,12 @@ func @addUnrankedUnranked(
 //                          Handle rank 3 specialization
 // CHECK:                   %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) {
 // CHECK:                     %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1]
-// CHECK:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<3xindex>
-// CHECK:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<3xindex>
-// CHECK:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
-// CHECK:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK:                     %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK:                     %[[CASTED_LHS_3:.*]] = tensor_cast %[[BROADCASTED_LHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK:                     %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
+// CHECK:                     %[[CASTED_RHS_3:.*]] = tensor_cast %[[BROADCASTED_RHS_3]] : tensor<?xindex> to tensor<3xindex>
+// CHECK:                     %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+// CHECK:                     %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
 // CHECK:                     %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:                     %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor<?x?x?xf32> to tensor<*xf32>
 // CHECK:                     scf.yield %[[RESULT_3]] : tensor<*xf32>
@@ -379,10 +383,12 @@ func @addUnrankedUnranked(
 //                            Handle rank 4 specialization
 // CHECK:                     %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) {
 // CHECK:                       %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1]
-// CHECK:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<4xindex>
-// CHECK:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<4xindex>
-// CHECK:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
-// CHECK:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK:                       %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK:                       %[[CASTED_LHS_4:.*]] = tensor_cast %[[BROADCASTED_LHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK:                       %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor<?xindex>, tensor<4xindex> -> tensor<?xindex>
+// CHECK:                       %[[CASTED_RHS_4:.*]] = tensor_cast %[[BROADCASTED_RHS_4]] : tensor<?xindex> to tensor<4xindex>
+// CHECK:                       %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
+// CHECK:                       %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
 // CHECK:                       %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK:                       %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor<?x?x?x?xf32> to tensor<*xf32>
 // CHECK:                       scf.yield %[[RESULT_4]] : tensor<*xf32>
@@ -392,10 +398,12 @@ func @addUnrankedUnranked(
 //                              Handle rank 5 specialization
 // CHECK:                       %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) {
 // CHECK:                         %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1]
-// CHECK:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<5xindex>
-// CHECK:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<5xindex>
-// CHECK:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
-// CHECK:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK:                         %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK:                         %[[CASTED_LHS_5:.*]] = tensor_cast %[[BROADCASTED_LHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK:                         %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor<?xindex>, tensor<5xindex> -> tensor<?xindex>
+// CHECK:                         %[[CASTED_RHS_5:.*]] = tensor_cast %[[BROADCASTED_RHS_5]] : tensor<?xindex> to tensor<5xindex>
+// CHECK:                         %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
+// CHECK:                         %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor<?x?x?x?x?xf32>
 // CHECK:                         %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
 // CHECK:                         %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor<?x?x?x?x?xf32> to tensor<*xf32>
 // CHECK:                         scf.yield %[[RESULT_5]] : tensor<*xf32>
@@ -405,10 +413,12 @@ func @addUnrankedUnranked(
 //                                Handle rank 6 specialization
 // CHECK:                         %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) {
 // CHECK:                           %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1]
-// CHECK:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<6xindex>
-// CHECK:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<6xindex>
-// CHECK:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
-// CHECK:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
+// CHECK:                           %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
+// CHECK:                           %[[CASTED_LHS_6:.*]] = tensor_cast %[[BROADCASTED_LHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK:                           %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor<?xindex>, tensor<6xindex> -> tensor<?xindex>
+// CHECK:                           %[[CASTED_RHS_6:.*]] = tensor_cast %[[BROADCASTED_RHS_6]] : tensor<?xindex> to tensor<6xindex>
+// CHECK:                           %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[CASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
+// CHECK:                           %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[CASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK:                           %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
 // CHECK:                           %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor<?x?x?x?x?x?xf32> to tensor<*xf32>
 // CHECK:                           scf.yield %[[RESULT_6]] : tensor<*xf32>