Do not mandate the result type of shape computations; let it be inferred from context.

The computation of a broadcasted shape unnecessarily forced the use of the shape type, which blocked further canonicalizations.

PiperOrigin-RevId: 323783998
commit 1b0eb4baa7
parent cd01bb4c4e
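The mechanical difference is small: MLIR's OpBuilder can either be handed an explicit result type or left to use an op's type-inference builder. A minimal sketch of the two styles described in the message above (the value names here are placeholders, not from this commit):

  // Before: the result type is mandated, pinning every shape value to
  // !shape.shape regardless of what is known about the operand.
  auto shape_type = shape::ShapeType::get(builder.getContext());
  Value pinned =
      builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, operand);

  // After: no result type is passed, so the ShapeOfOp builder infers one
  // from the operand, and createOrFold can canonicalize the result further.
  Value inferred = builder.createOrFold<shape::ShapeOfOp>(loc, operand);

As the updated tests below show, this lets the broadcasted shape computed inside shape.assuming reuse the shape.shape_of values that already feed the shape.cstr_broadcastable witness, so the duplicated shape.shape_of ops inside the region fold away.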
@@ -58,13 +58,10 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
   }
 
   int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
-  auto shape_type = shape::ShapeType::get(builder.getContext());
-  Value lhs_shape_v =
-      builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, lhs);
-  Value rhs_shape_v =
-      builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, rhs);
+  Value lhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, lhs);
+  Value rhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, rhs);
   Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
-      loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */);
+      loc, lhs_shape_v, rhs_shape_v, nullptr /* error */);
   return builder.createOrFold<shape::ToExtentTensorOp>(
       loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
       result_shape_v);
@@ -18,9 +18,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
 // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
 // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
 // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
-// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
-// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
+// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
 // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
 // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
 // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
@@ -41,9 +39,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
 // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
 // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
 // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
-// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
-// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
+// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
 // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
 // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@@ -64,9 +60,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
 // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
 // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
 // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
-// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
-// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
+// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
 // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
 // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@@ -269,7 +263,6 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
 // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
 // CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
-// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
 // CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
 // CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
 // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
@@ -306,10 +299,9 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
 // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
-// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
-// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.to_extent_tensor %[[SHAPE_OF]]
-// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
-// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]]
+// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
 // CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
 // CHECK: }