From a68a16cdc71c0a6f9a66a47b0c7c316dfa4cf261 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 6 Aug 2020 07:46:30 -0700
Subject: [PATCH] [MLIR][XLA] Allow for choice of safe/unsafe variant in
 broadcast utils

Create safe or unsafe variants of `shape.broadcast` depending on the
context. The representation by means of an extent tensor is only legal
if the operands are known to be broadcastable. Currently, there is no
use in a safe context in the codebase, but it will be used for shape
inference eventually.
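
For illustration (hypothetical IR, mirroring the updated tests below):
given ranked operands whose shapes %lhs_shape and %rhs_shape have been
computed by shape.shape_of, the safe variant materializes the broadcast
as an error-carrying shape value,

  %bcast = shape.broadcast %lhs_shape, %rhs_shape
      : tensor<?xindex>, tensor<?xindex> -> !shape.shape

while the unsafe variant yields an extent tensor directly and casts it
to the statically known result rank (here 2), which is only legal once
the operands are known to be broadcastable, e.g. under a
shape.cstr_broadcastable witness:

  %bcast = shape.broadcast %lhs_shape, %rhs_shape
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  %extents = tensor_cast %bcast : tensor<?xindex> to tensor<2xindex>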
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/StandardTypes.h" @@ -46,9 +47,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, broadcast_dims.getIntValues().begin()); } -Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, - Value rhs, - OpBuilder& builder) { +Value ComputeBinaryElementwiseBroadcastingResultExtents( + Location loc, Value lhs, Value rhs, OpBuilder& builder, + bool unsafe_as_extent_tensor) { auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) { @@ -57,15 +58,22 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, return nullptr; } - int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value lhs_shape_v = builder.createOrFold(loc, lhs); Value rhs_shape_v = builder.createOrFold(loc, rhs); - Value result_shape_v = builder.createOrFold( - loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v, - rhs_shape_v, nullptr /* error */); - return builder.createOrFold( - loc, RankedTensorType::get({result_rank}, builder.getIndexType()), - result_shape_v); + + if (unsafe_as_extent_tensor) { + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + Value result_shape_v = builder.createOrFold( + loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v, + rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); + } + + return builder.createOrFold( + loc, builder.getType(), lhs_shape_v, rhs_shape_v, + nullptr /* error */); } } // namespace hlo diff --git a/tests/chlo_infer_shape_type_methods.mlir b/tests/chlo_infer_shape_type_methods.mlir index 99aab53..d226c92 100644 --- a/tests/chlo_infer_shape_type_methods.mlir +++ b/tests/chlo_infer_shape_type_methods.mlir @@ -5,15 +5,14 @@ // only test reification on an examplar op. 
diff --git a/tests/chlo_infer_shape_type_methods.mlir b/tests/chlo_infer_shape_type_methods.mlir
index 99aab53..d226c92 100644
--- a/tests/chlo_infer_shape_type_methods.mlir
+++ b/tests/chlo_infer_shape_type_methods.mlir
@@ -5,15 +5,14 @@
 // only test reification on an examplar op.
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
-func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
+func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> !shape.shape {
   // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
   // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
-  // CHECK: return %[[EXTENTS]]
+  // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+  // CHECK: return %[[BCAST_S]] : !shape.shape
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
-  return %1 : tensor<1xindex>
+  %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> !shape.shape
+  return %1 : !shape.shape
 }
 
 // -----
diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir
index c08ead5..9670372 100644
--- a/tests/chlo_legalize_to_hlo_broadcasts.mlir
+++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir
@@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
+  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
   // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
@@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
   // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
+  // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
   // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
   // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
   // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
+  // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
   // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
   // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
@@ -263,7 +263,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
 // CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
 // CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
-// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
+// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
 // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
@@ -296,7 +296,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
 // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
 // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
 // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
-// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]]
+// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
 // CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
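
As the legalization tests above illustrate, the extent-tensor form is only
emitted inside a region guarded by a broadcastability witness. A minimal
sketch of that guarded pattern (hypothetical values, following the CHECK
lines above):

  %witness = shape.cstr_broadcastable %lhs_shape, %rhs_shape
  %result = shape.assuming %witness -> (tensor<?x?xf32>) {
    %bcast = shape.broadcast %lhs_shape, %rhs_shape
    %extents = tensor_cast %bcast : tensor<?xindex> to tensor<2xindex>
    // ... mhlo.dynamic_broadcast_in_dim of both operands, then the
    // elementwise op on the broadcasted values ...
    shape.assuming_yield %elementwise_result : tensor<?x?xf32>
  }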