[MLIR][XLA] Allow for choice of safe/unsafe variant in broadcast utils
Create safe or unsafe variants of `shape.broadcast` depending on the context. The representation by means of an extent tensor is only legal if the operands are known to be broadcastable. Currently, there is no use in a safe context in the codebase but it will be used for shape inference eventually. PiperOrigin-RevId: 325228073
This commit is contained in:
parent
bc3293a05f
commit
a68a16cdc7
|
@ -38,10 +38,12 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
|
|||
|
||||
// Emits shape dialect ops to compute the result shape for a broadcasting
|
||||
// binary elementwise op which broadcasts according to "numpy" semantics
|
||||
// (see above), returning an extents tensor of the resulting shape.
|
||||
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
|
||||
Value rhs,
|
||||
OpBuilder& builder);
|
||||
// (see above), returning a `shape.shape` or an extent tensor of the resulting
|
||||
// shape. The result should only be an extent tensor in contexts that ensure
|
||||
// both operands to be broadcastable.
|
||||
Value ComputeBinaryElementwiseBroadcastingResultExtents(
|
||||
Location loc, Value lhs, Value rhs, OpBuilder& builder,
|
||||
bool unsafe_as_extent_tensor);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -151,7 +151,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
|
|||
}
|
||||
|
||||
Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
|
||||
loc, lhs, rhs, builder);
|
||||
loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
|
||||
if (!computed_shape) return failure();
|
||||
reifiedReturnShapes.push_back(computed_shape);
|
||||
return success();
|
||||
|
|
|
@ -124,8 +124,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
|
|||
|
||||
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||
Value result_extents =
|
||||
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
|
||||
rewriter);
|
||||
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
|
||||
loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);
|
||||
|
||||
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
|
||||
// downstream canonicalizations fold them away if possible. This is
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
|
@ -46,9 +47,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
|
|||
broadcast_dims.getIntValues().begin());
|
||||
}
|
||||
|
||||
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
|
||||
Value rhs,
|
||||
OpBuilder& builder) {
|
||||
Value ComputeBinaryElementwiseBroadcastingResultExtents(
|
||||
Location loc, Value lhs, Value rhs, OpBuilder& builder,
|
||||
bool unsafe_as_extent_tensor) {
|
||||
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
if (!lhs_type || !rhs_type) {
|
||||
|
@ -57,15 +58,22 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||
Value lhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape_v = builder.createOrFold<shape::ShapeOfOp>(loc, rhs);
|
||||
|
||||
if (unsafe_as_extent_tensor) {
|
||||
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
|
||||
loc, shape::ShapeType::get(builder.getContext()), lhs_shape_v,
|
||||
loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v,
|
||||
rhs_shape_v, nullptr /* error */);
|
||||
return builder.createOrFold<shape::ToExtentTensorOp>(
|
||||
return builder.createOrFold<TensorCastOp>(
|
||||
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
|
||||
result_shape_v);
|
||||
}
|
||||
|
||||
return builder.createOrFold<shape::BroadcastOp>(
|
||||
loc, builder.getType<shape::ShapeType>(), lhs_shape_v, rhs_shape_v,
|
||||
nullptr /* error */);
|
||||
}
|
||||
|
||||
} // namespace hlo
|
||||
|
|
|
@ -5,15 +5,14 @@
|
|||
// only test reification on an examplar op.
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
|
||||
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
|
||||
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> !shape.shape {
|
||||
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
|
||||
// CHECK: return %[[EXTENTS]]
|
||||
// CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor<?xindex>, tensor<?xindex> -> !shape.shape
|
||||
// CHECK: return %[[BCAST_S]] : !shape.shape
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
|
||||
return %1 : tensor<1xindex>
|
||||
%1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> !shape.shape
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -19,7 +19,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
|
|||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
|
||||
|
@ -40,7 +40,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
|
||||
|
@ -61,7 +61,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
|
|||
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
|
||||
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
|
||||
// CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor<?xindex> to tensor<2xindex>
|
||||
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
|
@ -263,7 +263,7 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
|
|||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
|
||||
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
|
||||
// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor<?xindex> to tensor<1xindex>
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
|
@ -296,7 +296,7 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
|
|||
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
|
||||
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
|
||||
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
|
||||
// CHECK: %[[ASTENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]]
|
||||
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
|
||||
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
|
||||
|
|
Loading…
Reference in New Issue