Improve broadcast transformation to treat dynamic shapes with 1 element as scalar.

A shape that contains exactly one element is effectively a scalar. This leads
to a speedup in cases where we have a binary op with one operand that is
effectively a scalar, because we can use the fast path.

PiperOrigin-RevId: 357515552
This commit is contained in:
Adrian Kuegel 2021-02-14 23:24:45 -08:00 committed by TensorFlow MLIR Team
parent 4060a86fe2
commit 824bc9c425
2 changed files with 91 additions and 67 deletions

View File

@ -230,46 +230,54 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
// pattern will handle the lowering.
if (!lhs_type || !rhs_type) return failure();
// If lhs is scalar
Value shape_of_lhs = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs = rewriter.create<shape::ShapeOfOp>(loc, rhs);
// If lhs has exactly one element
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
loc, result_type, IsSingleElementShape(rewriter, op, shape_of_lhs),
true);
OpBuilder if_lhs_scalar_builder =
if_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_lhs = if_lhs_scalar_builder.create<tensor::CastOp>(
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
op.getAttrs());
if_lhs_scalar_builder.create<scf::YieldOp>(loc, if_lhs_scalar_result);
Value extended_if_lhs_scalar_result =
extendToBroadcastShape(if_lhs_scalar_builder, loc, if_lhs_scalar_result,
shape_of_lhs, shape_of_rhs);
if_lhs_scalar_builder.create<scf::YieldOp>(loc,
extended_if_lhs_scalar_result);
// If lhs is NOT scalar
// If lhs does not have exactly one element
//
// See if rhs is scalar
// See if rhs has exactly one element
OpBuilder else_lhs_scalar_builder =
if_op.getElseBodyBuilder(rewriter.getListener());
auto if_rhs_scalar_op = else_lhs_scalar_builder.create<scf::IfOp>(
loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
true);
loc, result_type,
IsSingleElementShape(else_lhs_scalar_builder, op, shape_of_rhs), true);
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder =
if_rhs_scalar_op.getThenBodyBuilder(rewriter.getListener());
Value reshaped_rhs = if_rhs_scalar_builder.create<tensor::CastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, rhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
op.getAttrs());
if_rhs_scalar_builder.create<scf::YieldOp>(loc, if_rhs_scalar_result);
Value extended_if_rhs_scalar_result =
extendToBroadcastShape(if_rhs_scalar_builder, loc, if_rhs_scalar_result,
shape_of_lhs, shape_of_rhs);
if_rhs_scalar_builder.create<scf::YieldOp>(loc,
extended_if_rhs_scalar_result);
// If NEITHER shape is scalar
// If NEITHER shape has exactly one element
//
// See if shapes are equal.
OpBuilder else_no_scalars_builder =
if_rhs_scalar_op.getElseBodyBuilder(rewriter.getListener());
Value shape_of_lhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, lhs);
Value shape_of_rhs =
else_no_scalars_builder.create<shape::ShapeOfOp>(loc, rhs);
Value equal_shapes = else_no_scalars_builder.create<shape::ShapeEqOp>(
loc, shape_of_lhs, shape_of_rhs);
@ -284,7 +292,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
if_eq_shapes_builder.create<scf::YieldOp>(loc, non_broadcast_op);
// If shapes are not scalar, nor equal
// If shapes do not have exactly one element, nor are equal
//
// See if values are of a rank that we support.
OpBuilder if_neq_shapes_builder =
@ -297,16 +305,17 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
}
private:
// Returns the dynamic result of checking the given value is a scalar tensor.
Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
// Returns the dynamic result of checking the given value is effectively a
// scalar shape (i.e. the number of elements is 1).
Value IsSingleElementShape(OpBuilder &rewriter, ChloOpTy op,
Value shape_of_tensor) const {
auto loc = op.getLoc();
Value shape_of_tensor = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value rank_tensor = rewriter.create<shape::RankOp>(
loc, rewriter.getIndexType(), shape_of_tensor);
Value num_elements =
rewriter.create<shape::NumElementsOp>(loc, shape_of_tensor);
return rewriter.create<CmpIOp>(loc, rewriter.getI1Type(), CmpIPredicate::eq,
rank_tensor,
rewriter.create<ConstantIndexOp>(loc, 0));
num_elements,
rewriter.create<ConstantIndexOp>(loc, 1));
}
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
@ -326,6 +335,36 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
greater_rank_is_n, true);
}
// Reshapes `value` to the broadcasted shape of the two operand shapes.
//
// A shape.broadcast op is created from `shape_of_lhs` and `shape_of_rhs`
// (yielding a dynamically sized extent tensor of index type), and `value` is
// then passed through mhlo.dynamic_reshape with that shape. The result keeps
// the type of `value`.
Value extendToBroadcastShape(OpBuilder &builder, Location loc, Value value,
                             Value shape_of_lhs, Value shape_of_rhs) const {
  // Extent tensor whose length (the rank) is not known statically.
  auto dynamic_extent_tensor_type = RankedTensorType::get(
      {RankedTensorType::kDynamicSize}, builder.getIndexType());

  Value result_shape = builder.create<shape::BroadcastOp>(
      loc, dynamic_extent_tensor_type, shape_of_lhs, shape_of_rhs, nullptr);

  auto result_type = value.getType();
  return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
                                                result_shape);
}
// Produces the shape of `value` extended to exactly `targeted_rank` extents.
//
// The shape of `value` is broadcast against a constant shape made of
// `targeted_rank` ones, and the resulting dynamically sized extent tensor is
// cast to a statically sized extent tensor with `targeted_rank` elements.
Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value value,
                                 int targeted_rank) const {
  auto loc = op.getLoc();
  Value value_shape = builder.create<shape::ShapeOfOp>(loc, value);

  // Extent tensor types: one with unknown length, one with the targeted rank.
  auto dynamic_extent_tensor_type = RankedTensorType::get(
      {RankedTensorType::kDynamicSize}, builder.getIndexType());
  auto static_extent_tensor_type =
      RankedTensorType::get({targeted_rank}, builder.getIndexType());

  // Constant shape [1, 1, ..., 1] with `targeted_rank` entries.
  SmallVector<int64_t, 6> all_ones(targeted_rank, 1);
  auto all_ones_attr =
      mlir::DenseIntElementsAttr::get(static_extent_tensor_type, all_ones);
  Value all_ones_shape = builder.create<shape::ConstShapeOp>(
      loc, static_extent_tensor_type, all_ones_attr);

  // Broadcasting against the all-ones shape pads the operand shape up to the
  // targeted rank; the cast then makes that rank static in the type.
  Value broadcasted_shape = builder.create<shape::BroadcastOp>(
      loc, dynamic_extent_tensor_type, value_shape, all_ones_shape, nullptr);
  return builder.create<tensor::CastOp>(loc, static_extent_tensor_type,
                                        broadcasted_shape);
}
// Create the if statement and code for a broadcasting op with a result of a
// given rank.
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
@ -333,32 +372,16 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
int targeted_rank) const {
auto loc = op.getLoc();
// Handle shape broadcasting and inferrence.
Value lhs_shape = if_builder.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = if_builder.create<shape::ShapeOfOp>(loc, rhs);
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
{RankedTensorType::kDynamicSize}, if_builder.getIndexType());
auto known_rank_extent_tensor_type =
RankedTensorType::get({targeted_rank}, if_builder.getIndexType());
// Handle shape broadcasting and inference.
Value extended_lhs_casted =
createBroadcastToKnownRank(if_builder, op, lhs, targeted_rank);
Value extended_rhs_casted =
createBroadcastToKnownRank(if_builder, op, rhs, targeted_rank);
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
targeted_rank, RankedTensorType::kDynamicSize);
auto reshaped_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
dynamic_dimensions,
lhs.getType().template dyn_cast<TensorType>().getElementType());
Value ranked_shape_val = if_builder.create<shape::ConstShapeOp>(
loc, known_rank_extent_tensor_type,
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
ranked_shape));
Value extended_lhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, lhs_shape, ranked_shape_val,
nullptr);
Value extended_lhs_casted = if_builder.create<tensor::CastOp>(
loc, known_rank_extent_tensor_type, extended_lhs);
Value extended_rhs = if_builder.create<shape::BroadcastOp>(
loc, unknown_rank_extent_tensor_type, rhs_shape, ranked_shape_val,
nullptr);
Value extended_rhs_casted = if_builder.create<tensor::CastOp>(
loc, known_rank_extent_tensor_type, extended_rhs);
// 1. Reshape operands to the given rank (with the same number of elements)
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops
@ -372,10 +395,8 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
.getType()
.template dyn_cast<TensorType>()
.getElementType();
auto result_type = RankedTensorType::get(
llvm::SmallVector<int64_t, 6>(targeted_rank,
RankedTensorType::kDynamicSize),
result_element_type);
auto result_type =
RankedTensorType::get(dynamic_dimensions, result_element_type);
Value result = if_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type},
ArrayRef<Value>{reshaped_lhs, reshaped_rhs}, op.getAttrs());

View File

@ -158,32 +158,34 @@ func @addUnrankedUnranked(
// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[LHS_RANK]], %[[C0]] : index
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[LHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_LHS]], %[[C1]] : index
// Handle scalar LHS case
// CHECK-NEXT: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = tensor.cast %[[LHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[RHS_SHAPE_1:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE_1]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[NUM_TENS_RHS:.*]] = tensor.from_elements %[[NUM_RHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_RHS:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[NUM_TENS_RHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[LHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RESHAPED_RHS]] : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE_1]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_LHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: %[[RESHAPED_LHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[LHS_SCALAR_RESULT]], %[[RHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST_LHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_LHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_LHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_LHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_LHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor<?xindex>
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[RHS_RANK]], %[[C0]] : index
// CHECK-NEXT: %[[NUM_RHS:.*]] = shape.num_elements %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_IS_SCALAR:.*]] = cmpi eq, %[[NUM_RHS]], %[[C1]] : index
// Handle scalar RHS case
// CHECK-NEXT: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = tensor.cast %[[RHS]] : tensor<*xf32> to tensor<f32>
// CHECK-NEXT: %[[NUM_LHS:.*]] = shape.num_elements %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor<f32>
// CHECK-NEXT: %[[NUM_TENS_LHS:.*]] = tensor.from_elements %[[NUM_LHS]] : tensor<1xindex>
// CHECK-NEXT: %[[RESHAPED_LHS:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[NUM_TENS_LHS]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[RHS_SCALAR_RESULT:.*]] = chlo.broadcast_add %[[RESHAPED_LHS]], %[[SCALAR_RHS]] : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: %[[RESHAPED_RHS_SCALAR_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RHS_SCALAR_RESULT]], %[[LHS_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_RHS_SCALAR_RESULT]] : tensor<*xf32>
// CHECK-NEXT: %[[SHAPE_BROADCAST_RHS:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
// CHECK-NEXT: %[[RESHAPED_EXTENDED_RHS_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[RESHAPED_RHS_SCALAR_RESULT]], %[[SHAPE_BROADCAST_RHS]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_EXTENDED_RHS_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor<?xindex>, tensor<?xindex>
// Handle equal shapes case
@ -197,10 +199,11 @@ func @addUnrankedUnranked(
// CHECK-NEXT: %[[RESHAPED_SAME_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLATTENED_RESULT]], %[[ANY_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: scf.yield %[[RESHAPED_SAME_RESULT]] : tensor<*xf32>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[LHS_RANK:.*]] = shape.rank %[[LHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[RHS_RANK:.*]] = shape.rank %[[RHS_SHAPE]] : tensor<?xindex> -> index
// CHECK-NEXT: %[[LHS_RANK_GREATER:.*]] = cmpi sgt, %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK-NEXT: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// Handle rank 1 specialization
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[GREATEST_RANK_IS_1:.*]] = cmpi eq, %[[GREATEST_RANK]], %[[C1]] : index
// CHECK-NEXT: %[[RESULT_RANK_1:.*]] = scf.if %[[GREATEST_RANK_IS_1]] -> (tensor<*xf32>) {
// CHECK-NEXT: %[[CONST_SHAPE_1:.*]] = shape.const_shape [1]