Do not mandate the result type of shape computations but have it be inferred from context.
The computation of a broadcasted shape forced the use of the shape type unnecessarily, which blocked further canonicalizations. PiperOrigin-RevId: 323783998
This commit is contained in:
		
							parent
							
								
									cd01bb4c4e
								
							
						
					
					
						commit
						1b0eb4baa7
					
				|  | @ -58,13 +58,10 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, | |||
|   } | ||||
| 
 | ||||
|   int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); | ||||
|   auto shape_type = shape::ShapeType::get(builder.getContext()); | ||||
|   Value lhs_shape_v = | ||||
|       builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, lhs); | ||||
|   Value rhs_shape_v = | ||||
|       builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, rhs); | ||||
|   Value lhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, lhs); | ||||
|   Value rhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, rhs); | ||||
|   Value result_shape_v = builder.createOrFold<shape::BroadcastOp>( | ||||
|       loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */); | ||||
|       loc, lhs_shape_v, rhs_shape_v, nullptr /* error */); | ||||
|   return builder.createOrFold<shape::ToExtentTensorOp>( | ||||
|       loc, RankedTensorType::get({result_rank}, builder.getIndexType()), | ||||
|       result_shape_v); | ||||
|  |  | |||
|  | @ -18,9 +18,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<? | |||
|   // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] | ||||
|   // CHECK-DAG:  %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]] | ||||
|   // CHECK-DAG:  %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK-DAG:    %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]] | ||||
|   // CHECK-DAG:    %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK:        %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] | ||||
|   // CHECK-DAG:    %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} | ||||
|   // CHECK-DAG:    %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} | ||||
|  | @ -41,9 +39,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t | |||
|   // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] | ||||
|   // CHECK-DAG:  %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]] | ||||
|   // CHECK-DAG:  %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK-NEXT:   %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]] | ||||
|   // CHECK-NEXT:   %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK-NEXT:   %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] | ||||
|   // CHECK-DAG:    %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|   // CHECK-DAG:    %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|  | @ -64,9 +60,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t | |||
|   // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] | ||||
|   // CHECK-DAG:  %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]] | ||||
|   // CHECK-DAG:  %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]] | ||||
|   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]] | ||||
|   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] | ||||
|   // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] | ||||
|   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|  | @ -269,7 +263,6 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3 | |||
| // CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]] | ||||
| // CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { | ||||
| // CHECK:             %[[SCALAR_SHAPE:.*]] = shape.const_shape [] | ||||
| // CHECK:             %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] | ||||
| // CHECK:             %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex> | ||||
| // CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> | ||||
|  | @ -306,10 +299,9 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3 | |||
| // CHECK:           %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32> | ||||
| // CHECK:           %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] | ||||
| // CHECK:           %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) { | ||||
| // CHECK:             %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32> | ||||
| // CHECK:             %[[SHAPE_RESHAPED:.*]] = shape.to_extent_tensor %[[SHAPE_OF]] | ||||
| // CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] | ||||
| // CHECK:             %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> | ||||
| // CHECK:             %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32> | ||||
| // CHECK:             shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32> | ||||
| // CHECK:           } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue