From 1b0eb4baa74a9ef703f6dedc9d6c95817dbc8ffc Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 29 Jul 2020 07:35:52 -0700 Subject: [PATCH] Do not mandate the result type of shape computations but have it be inferred from context. The computation of a broadcasted shape forced the use of the shape type unnecessarily, which blocked further canonicalizations. PiperOrigin-RevId: 323783998 --- lib/utils/broadcast_utils.cc | 9 +++------ tests/chlo_legalize_to_hlo_broadcasts.mlir | 20 ++++++-------------- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc index 73111c0..c446626 100644 --- a/lib/utils/broadcast_utils.cc +++ b/lib/utils/broadcast_utils.cc @@ -58,13 +58,10 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, } int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); - auto shape_type = shape::ShapeType::get(builder.getContext()); - Value lhs_shape_v = - builder.createOrFold(loc, shape_type, lhs); - Value rhs_shape_v = - builder.createOrFold(loc, shape_type, rhs); + Value lhs_shape_v = builder.createOrFold(loc, lhs); + Value rhs_shape_v = builder.createOrFold(loc, rhs); Value result_shape_v = builder.createOrFold( - loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */); + loc, lhs_shape_v, rhs_shape_v, nullptr /* error */); return builder.createOrFold( loc, RankedTensorType::get({result_rank}, builder.getIndexType()), result_shape_v); diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir index 997136e..96e4fa3 100644 --- a/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -18,9 +18,7 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor : tensor<1xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} @@ -41,9 +39,7 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]] - // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]] + // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor @@ -64,9 +60,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]] - // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]] + // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor @@ -269,7 +263,6 @@ func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf3 // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]] // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { // CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape [] -// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor // CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] // CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex> // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor @@ -306,10 +299,9 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf3 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { -// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor -// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.to_extent_tensor %[[SHAPE_OF]] -// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor -// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] +// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor // CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor // CHECK: }