Improve broadcast transformation to treat dynamic shapes with 1 element as scalar.
A shape that contains exactly one element is effectively a scalar. This leads to a speedup in cases where we have a binary op with one operand that is effectively a scalar, because we can use the fast path.

PiperOrigin-RevId: 357515552
commit 824bc9c425 (parent 4060a86fe2)
@@ -230,46 +230,54 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     // pattern will handle the lowering.
     if (!lhs_type || !rhs_type) return failure();
 
-    // If lhs is scalar
+    Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
+    Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
+
+    // If lhs has exactly one element
     auto if_op = rewriter.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
+        loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
+        true);
     OpBuilder if_lhs_scalar_builder =
         if_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
+    Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
         loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
     Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
         op.getAttrs());
-    if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
+    Value extended_if_lhs_scalar_result =
+        extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
+                               shape_of_lhs, shape_of_rhs);
+    if_lhs_scalar_builder.create<scf::YieldOp>(loc,
+                                               extended_if_lhs_scalar_result);
 
-    // If lhs is NOT scalar
+    // If lhs does not have exactly one element
     //
-    // See if rhs is scalar
+    // See if rhs has exactly one element
     OpBuilder else_lhs_scalar_builder =
         if_op.getElseBodyBuilder(rewriter.getListener());
     auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
-        loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
-        true);
+        loc, result_type,
+        IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
     else_lhs_scalar_builder.create<scf::YieldOp>(loc,
                                                  if_rhs_scalar_op.getResult(0));
     OpBuilder if_rhs_scalar_builder =
         if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
-    Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
-        loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
+    Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
     Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
         op.getAttrs());
-    if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
+    Value extended_if_rhs_scalar_result =
+        extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
+                               shape_of_lhs, shape_of_rhs);
+    if_rhs_scalar_builder.create<scf::YieldOp>(loc,
+                                               extended_if_rhs_scalar_result);
 
-    // If NEITHER shape is scalar
+    // If NEITHER shape has exactly one element
     //
     // See if shapes are equal.
     OpBuilder else_no_scalars_builder =
         if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
-    Value shape_of_lhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value shape_of_rhs =
-        else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
     Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
         loc, shape_of_lhs, shape_of_rhs);
@@ -284,7 +292,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
     Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
     if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
 
-    // If shapes are not scalar, nor equal
+    // If shapes do not have exactly one element, nor are equal
     //
     // See if values are of a rank that we support.
     OpBuilder if_neq_shapes_builder =
@@ -297,16 +305,17 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
   }
 
  private:
-  // Returns the dynamic result of checking the given value is a scalar tensor.
-  Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
+  // Returns the dynamic result of checking the given value is effectively a
+  // scalar shape (i.e. the number of elements is 1).
+  Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
+                             Value shape_of_tensor) const {
     auto loc = op.getLoc();
 
-    Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
-    Value rank_tensor = rewriter.create<shape::RankOp>(
-        loc, rewriter.getIndexType(), shape_of_tensor);
+    Value num_elements =
+        rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
     return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
-                                   rank_tensor,
-                                   rewriter.create<ConstantIndexOp>(loc, 0));
+                                   num_elements,
+                                   rewriter.create<ConstantIndexOp>(loc, 1));
   }
 
   Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
@@ -326,6 +335,36 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                greater_rank_is_n, true);
   }
 
+  Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
+                               Value shape_of_lhs, Value shape_of_rhs) const {
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    Value broadcast_shape =
+        builder.create<shape::BroadcastOp>(loc, unknown_rank_extent_tensor_type,
+                                           shape_of_lhs, shape_of_rhs, nullptr);
+    return builder.create<mhlo::DynamicReshapeOp>(loc, value.getType(), value,
+                                                  broadcast_shape);
+  }
+
+  Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
+                                   int targeted_rank) const {
+    auto loc = op.getLoc();
+    Value shape = builder.create<shape::ShapeOfOp>(loc, value);
+    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
+    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
+        {RankedTensorType::kDynamicSize}, builder.getIndexType());
+    auto known_rank_extent_tensor_type =
+        RankedTensorType::get({targeted_rank}, builder.getIndexType());
+    Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
+        loc, known_rank_extent_tensor_type,
+        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
+                                        ranked_shape));
+    Value extended_value = builder.create<shape::BroadcastOp>(
+        loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
+    return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
+                                          extended_value);
+  }
+
   // Create the if statement and code for a broadcasting op with a result of a
   // given rank.
   void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
@@ -333,32 +372,16 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
                                            int targeted_rank) const {
     auto loc = op.getLoc();
 
-    // Handle shape broadcasting and inferrence.
-    Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
-    Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
-    SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
-    auto unknown_rank_extent_tensor_type = RankedTensorType::get(
-        {RankedTensorType::kDynamicSize}, if_builder.getIndexType());
-    auto known_rank_extent_tensor_type =
-        RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
+    // Handle shape broadcasting and inference.
+    Value extended_lhs_casted =
+        createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
+    Value extended_rhs_casted =
+        createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
+    auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
+        targeted_rank, RankedTensorType::kDynamicSize);
     auto reshaped_type = RankedTensorType::get(
-        llvm::SmallVector<int64_t, 6>(targeted_rank,
-                                      RankedTensorType::kDynamicSize),
+        dynamic_dimensions,
         lhs.getType().template dyn_cast<TensorType>().getElementType());
-    Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
-        loc, known_rank_extent_tensor_type,
-        mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
-                                        ranked_shape));
-    Value extended_lhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
-        loc, known_rank_extent_tensor_type, extended_lhs);
-    Value extended_rhs = if_builder.create<shape::BroadcastOp>(
-        loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
-        nullptr);
-    Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
-        loc, known_rank_extent_tensor_type, extended_rhs);
 
     // 1. Reshape operands to the given rank (with the same number of elements)
     // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
@@ -372,10 +395,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
             .getType()
             .template dyn_cast<TensorType>()
             .getElementType();
-    auto result_type = RankedTensorType::get(
-        llvm::SmallVector<int64_t, 6>(targeted_rank,
-                                      RankedTensorType::kDynamicSize),
-        result_element_type);
+    auto result_type =
+        RankedTensorType::get(dynamic_dimensions, result_element_type);
     Value result = if_builder.create<ChloOpTy>(
         loc, ArrayRef<Type>{result_type},
         ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());
@@ -158,32 +158,34 @@ func @addUnrankedUnranked(
 // CHECK-SAME:  %[[LHS:.*]]: tensor<*xf32>,
 // CHECK-SAME:  %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-NEXT:  %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT:  %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT:  %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT:  %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[LHS_RANK]], %[[C0]] : index
+// CHECK-NEXT:  %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK-NEXT:  %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:  %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT:  %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
 // Handle scalar LHS case
 // CHECK-NEXT:  %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT:    %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT:    %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT:    %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
+// CHECK-NEXT:    %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
+// CHECK-NEXT:    %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
 // CHECK-NEXT:    %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
 // CHECK-NEXT:    %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:    %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
-// CHECK-NEXT:    %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT:    scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT:    %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:    %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT:    %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:    scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
 // CHECK-NEXT:  } else {
-// CHECK-NEXT:    %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
-// CHECK-NEXT:    %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
-// CHECK-NEXT:    %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[RHS_RANK]], %[[C0]] : index
+// CHECK-NEXT:    %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:    %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
 // Handle scalar RHS case
 // CHECK-NEXT:    %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
-// CHECK-NEXT:      %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
-// CHECK-NEXT:      %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
+// CHECK-NEXT:      %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
 // CHECK-NEXT:      %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
 // CHECK-NEXT:      %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
 // CHECK-NEXT:      %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
 // CHECK-NEXT:      %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT:.*]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
-// CHECK-NEXT:      scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
+// CHECK-NEXT:      %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+// CHECK-NEXT:      %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK-NEXT:      scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
 // CHECK-NEXT:  } else {
 // CHECK-NEXT:    %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
 // Handle equal shapes case
|
@ -197,10 +199,11 @@ func @addUnrankedUnranked(
|
|||
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
|
||||
// CHECK-NEXT: } else {
|
||||
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
|
||||
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// Handle rank 1 specialization
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
|
||||
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]
|
||||
|