[MLIR][XLA] Allow for choice of safe/unsafe variant in broadcast utils

Create safe or unsafe variants of `shape.broadcast` depending on the context.
The representation by means of an extent tensor is only legal if the operands
are known to be broadcastable. Currently, there is no use in a safe context in
the codebase but it will be used for shape inference eventually.

PiperOrigin-RevId: 325228073
This commit is contained in:
A. Unique TensorFlower 2020-08-06 07:46:30 -07:00 committed by Geoffrey Martin-Noble
parent bc3293a05f
commit a68a16cdc7
6 changed files with 37 additions and 28 deletions

View File

@ -38,10 +38,12 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
// Emits shape dialect ops to compute the result shape for a broadcasting
// binary elementwise op which broadcasts according to "numpy" semantics
// (see above), returning an extents tensor of the resulting shape.
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder);
// (see above), returning a `shape.shape` or an extent tensor of the resulting
// shape. The result should only be an extent tensor in contexts that ensure
// both operands to be broadcastable.
Value ComputeBinaryElementwiseBroadcastingResultExtents(
Location loc, Value lhs, Value rhs, OpBuilder& builder,
bool unsafe_as_extent_tensor);
} // namespace hlo
} // namespace mlir

View File

@ -151,7 +151,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
}
Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, builder);
loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
if (!computed_shape) return failure();
reifiedReturnShapes.push_back(computed_shape);
return success();

View File

@ -124,8 +124,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents =
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter);
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
// downstream canonicalizations fold them away if possible. This is

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"
@ -46,9 +47,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
broadcast_dims.getIntValues().begin());
}
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder) {
Value ComputeBinaryElementwiseBroadcastingResultExtents(
Location loc, Value lhs, Value rhs, OpBuilder& builder,
bool unsafe_as_extent_tensor) {
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) {
@ -57,15 +58,22 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
return nullptr;
}
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value lhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, rhs);
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v,
rhs_shape_v, nullptr /* error */);
return builder.createOrFold<shape::ToExtentTensorOp>(
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
result_shape_v);
if (unsafe_as_extent_tensor) {
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v,
rhs_shape_v, nullptr /* error */);
return builder.createOrFold<TensorCastOp>(
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
result_shape_v);
}
return builder.createOrFold<shape::BroadcastOp>(
loc, builder.getType<shape::ShapeType>(), lhs_shape_v, rhs_shape_v,
nullptr /* error */);
}
} // namespace hlo

View File

@ -5,15 +5,14 @@
// only test reification on an examplar op.
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> !shape.shape {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
// CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor<?xindex>, tensor<?xindex> -> !shape.shape
// CHECK: return %[[BCAST_S]] : !shape.shape
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
%1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> !shape.shape
return %1 : !shape.shape
}
// -----

View File

@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
@ -263,7 +263,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
@ -296,7 +296,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]]
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>