Remove unnecessary conversions between Shape and ExtentTensor.
PiperOrigin-RevId: 323981215
This commit is contained in:
		
							parent
							
								
									ffef8d6593
								
							
						
					
					
						commit
						cce4bddf4b
					
				|  | @ -197,8 +197,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp | |||
|     Value shape = | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs); | ||||
|     Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape); | ||||
|     Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements); | ||||
|     Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size); | ||||
|     Value size_tensor = | ||||
|         rewriter.create<TensorFromElementsOp>(loc, num_elements); | ||||
|     Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>( | ||||
|         loc, RankedTensorType::get({-1}, result_type.getElementType()), | ||||
|         lhs_is_scalar ? rhs : lhs, size_tensor); | ||||
|  | @ -211,10 +211,8 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp | |||
|         loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs()); | ||||
| 
 | ||||
|     // Reshape the result back into an unranked tensor.
 | ||||
|     Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>( | ||||
|         loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape); | ||||
|     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type, | ||||
|                                                         computed, shape_tensor); | ||||
|                                                         computed, shape); | ||||
| 
 | ||||
|     return success(); | ||||
|   } | ||||
|  | @ -278,18 +276,10 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp | |||
|     //
 | ||||
|     // See if shapes are equal.
 | ||||
|     OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder(); | ||||
|     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                                     rewriter.getIndexType()); | ||||
|     Value shape_of_lhs = | ||||
|         else_no_scalars_builder.create<shape::ToExtentTensorOp>( | ||||
|             loc, extent_tensor_type, | ||||
|             else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs) | ||||
|                 .getResult()); | ||||
|         else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs); | ||||
|     Value shape_of_rhs = | ||||
|         else_no_scalars_builder.create<shape::ToExtentTensorOp>( | ||||
|             loc, extent_tensor_type, | ||||
|             else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs) | ||||
|                 .getResult()); | ||||
|         else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs); | ||||
|     Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>( | ||||
|         loc, shape_of_lhs, shape_of_rhs); | ||||
| 
 | ||||
|  | @ -319,12 +309,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp | |||
|   // tensor.
 | ||||
|   Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const { | ||||
|     auto loc = op.getLoc(); | ||||
|     auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, | ||||
|                                                     rewriter.getIndexType()); | ||||
| 
 | ||||
|     Value shape_of_tensor = rewriter.create<shape::ToExtentTensorOp>( | ||||
|         loc, extent_tensor_type, | ||||
|         rewriter.create<shape::ShapeOfOp>(loc, tensor).getResult()); | ||||
|     Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor); | ||||
|     Value rank_tensor = rewriter.create<shape::RankOp>( | ||||
|         loc, rewriter.getIndexType(), shape_of_tensor); | ||||
|     return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq, | ||||
|  |  | |||
|  | @ -252,9 +252,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3 | |||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] | ||||
| // CHECK:           %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> | ||||
| // CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| //                  The assuming region is part of the second stage of lowering | ||||
| //                  with ranked broadcasting logic. | ||||
|  | @ -272,8 +271,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3 | |||
| // CHECK:           } | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK:             %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex> -> tensor<?xindex> | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK:         } | ||||
| 
 | ||||
|  | @ -289,9 +287,8 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3 | |||
| //                  First handle the dynamic reshaping of the unranked operand | ||||
| //                  to a 1D tensor. | ||||
| // CHECK:           %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] | ||||
| // CHECK:           %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex> | ||||
| // CHECK:           %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> | ||||
| // CHECK:           %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| //                  The assuming region is part of the second stage of lowering | ||||
| //                  with ranked broadcasting logic. | ||||
|  | @ -307,8 +304,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3 | |||
| // CHECK:           } | ||||
| //                  As part of the unranked logic, the result is reshaped back | ||||
| //                  to an unranked tensor. | ||||
| // CHECK:             %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_0]] | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> | ||||
| // CHECK:           return %[[RESHAPED_RESULT]] : tensor<*xf32> | ||||
| // CHECK:         } | ||||
| 
 | ||||
|  | @ -323,9 +319,8 @@ func @addUnrankedUnranked( | |||
| // CHECK-LABEL:   func @addUnrankedUnranked( | ||||
| // CHECK-SAME:          %[[LHS:.*]]: tensor<*xf32>, | ||||
| // CHECK-SAME:          %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { | ||||
| // CHECK:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> | ||||
| // CHECK:           %[[LHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[LHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_EXTENT_TENSOR]] | ||||
| // CHECK:           %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK:           %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK:           %[[C0:.*]] = constant 0 : index | ||||
| // CHECK:           %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index | ||||
| //                  Handle scalar LHS case | ||||
|  | @ -334,9 +329,8 @@ func @addUnrankedUnranked( | |||
| // CHECK:             %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32> | ||||
| // CHECK:             scf.yield %[[VAL_10]] : tensor<*xf32> | ||||
| // CHECK:           } else { | ||||
| // CHECK:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> | ||||
| // CHECK:             %[[RHS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[RHS_SHAPE]] : tensor<?xindex> | ||||
| // CHECK:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_EXTENT_TENSOR]] | ||||
| // CHECK:             %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex> | ||||
| // CHECK:             %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index | ||||
| // CHECK:             %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index | ||||
|   //                  Handle scalar RHS case | ||||
| // CHECK:             %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { | ||||
|  | @ -344,7 +338,7 @@ func @addUnrankedUnranked( | |||
| // CHECK:               %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32> | ||||
| // CHECK:               scf.yield %[[VAL_16]] : tensor<*xf32> | ||||
| // CHECK:             } else { | ||||
| // CHECK:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_EXTENT_TENSOR]], %[[RHS_EXTENT_TENSOR]] : tensor<?xindex>, tensor<?xindex> | ||||
| // CHECK:               %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> | ||||
|   //                    Handle scalar RHS case | ||||
| // CHECK:               %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { | ||||
| // CHECK:                 %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue