Move code from helper struct to the only user.
We don't need the separate helper struct anymore, because it is now only used in one place. PiperOrigin-RevId: 366012639
This commit is contained in:
parent
4033a56750
commit
c8157ba4df
|
@ -206,180 +206,6 @@ struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
|
|||
}
|
||||
};
|
||||
|
||||
template <typename ChloOpTy>
|
||||
struct ConvertUnrankedDynamicBroadcastOpHelper {
|
||||
// Returns the dynamic result of checking the given value is effectively a
|
||||
// scalar shape (i.e. the number of elements is 1).
|
||||
static Value GreaterRankIsN(OpBuilder &builder, Location loc,
|
||||
Value actual_rank, int targeted_rank) {
|
||||
return builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
}
|
||||
|
||||
static scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
|
||||
OpBuilder &builder, ChloOpTy op, Value actual_rank, int targeted_rank) {
|
||||
// Create the if block to place the current specialized logic in.
|
||||
Value greater_rank_is_n =
|
||||
GreaterRankIsN(builder, op.getLoc(), actual_rank, targeted_rank);
|
||||
return builder.create<scf::IfOp>(op.getLoc(), op.getResult().getType(),
|
||||
greater_rank_is_n, true);
|
||||
}
|
||||
|
||||
static Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op,
|
||||
Value shape, int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<int64_t, 6> ranked_shape(targeted_rank, 1);
|
||||
auto unknown_rank_extent_tensor_type = RankedTensorType::get(
|
||||
{RankedTensorType::kDynamicSize}, builder.getIndexType());
|
||||
auto known_rank_extent_tensor_type =
|
||||
RankedTensorType::get({targeted_rank}, builder.getIndexType());
|
||||
Value ranked_shape_val = builder.create<shape::ConstShapeOp>(
|
||||
loc, known_rank_extent_tensor_type,
|
||||
mlir::DenseIntElementsAttr::get(known_rank_extent_tensor_type,
|
||||
ranked_shape));
|
||||
Value extended_value = builder.create<shape::BroadcastOp>(
|
||||
loc, unknown_rank_extent_tensor_type, shape, ranked_shape_val, nullptr);
|
||||
return builder.create<tensor::CastOp>(loc, known_rank_extent_tensor_type,
|
||||
extended_value);
|
||||
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
static void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder,
|
||||
ChloOpTy op,
|
||||
ValueRange operands,
|
||||
ValueRange operand_shapes,
|
||||
int targeted_rank) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 2> reshaped_operands;
|
||||
|
||||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||
targeted_rank, RankedTensorType::kDynamicSize);
|
||||
|
||||
for (auto it : llvm::zip(operands, operand_shapes)) {
|
||||
Value operand, shape;
|
||||
std::tie(operand, shape) = it;
|
||||
// Handle shape broadcasting and inference.
|
||||
Value extended_operand_casted =
|
||||
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of
|
||||
// elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the
|
||||
// ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
dynamic_dimensions,
|
||||
operand.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, operand, extended_operand_casted);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
auto result_element_type = op.getResult()
|
||||
.getType()
|
||||
.template dyn_cast<TensorType>()
|
||||
.getElementType();
|
||||
auto result_type =
|
||||
RankedTensorType::get(dynamic_dimensions, result_element_type);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
|
||||
Value reshaped_result = if_builder.create<tensor::CastOp>(
|
||||
loc, UnrankedTensorType::get(result_element_type), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
static Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
|
||||
ValueRange operands) {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
SmallVector<Value> shapes;
|
||||
shapes.reserve(operands.size());
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
for (Value operand : operands) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
|
||||
loc, extent_tensor_type, shapes, nullptr);
|
||||
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
|
||||
auto reduced_shapes =
|
||||
rewriter
|
||||
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
|
||||
.results();
|
||||
SmallVector<Value> reshaped_operands;
|
||||
reshaped_operands.reserve(operands.size());
|
||||
for (auto it : llvm::zip(operands, reduced_shapes)) {
|
||||
Value operand;
|
||||
Value reduced_shape;
|
||||
std::tie(operand, reduced_shape) = it;
|
||||
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operand.getType(), operand, reduced_shape);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
|
||||
// Find the largest rank of the operands.
|
||||
Value greater_rank;
|
||||
for (Value shape : reduced_shapes) {
|
||||
Value rank =
|
||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||
if (!greater_rank) {
|
||||
greater_rank = rank;
|
||||
} else {
|
||||
Value greater_rank_compare = rewriter.create<CmpIOp>(
|
||||
loc, CmpIPredicate::sgt, greater_rank, rank);
|
||||
greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
|
||||
greater_rank, rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 1 to `kMaxRankSpecialization`.
|
||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, 1);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, i);
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
// Fire an assertion if none of the rank specializations applied (one of
|
||||
// the ranks was greater than `kMaxRankSpecialization`).
|
||||
else_builder.create<AssertOp>(
|
||||
loc,
|
||||
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
||||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank 5 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
|
||||
reduced_shapes, kMaxRankSpecialization);
|
||||
|
||||
// Return the reshaped result of the outermost if statement.
|
||||
auto result = if_op.getResult(0);
|
||||
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, result.getType(), result, broadcast_shape);
|
||||
return reshaped_result;
|
||||
}
|
||||
};
|
||||
|
||||
// Handles lowering of the following pattern to patterns that will be further
|
||||
// matched by other patterns until they result in LHLO:
|
||||
// %result = "chlo.op"(%op0, %op1, ...) : (<*xTy>, <*xTy>, ...) -> <*xTy>
|
||||
|
@ -498,8 +324,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
|||
if_eq_shapes_op.getElseBodyBuilder(rewriter.getListener());
|
||||
if_neq_shapes_builder.create<scf::YieldOp>(
|
||||
loc,
|
||||
ConvertUnrankedDynamicBroadcastOpHelper<ChloOpTy>::HandleBroadcastAndOp(
|
||||
if_neq_shapes_builder, op, transformed_operands));
|
||||
HandleBroadcastAndOp(if_neq_shapes_builder, op, transformed_operands));
|
||||
|
||||
rewriter.replaceOp(op, {if_op.getResult(0)});
|
||||
return success();
|
||||
|
@ -529,6 +354,177 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
|||
return builder.create<mhlo::DynamicReshapeOp>(loc, result_type, value,
|
||||
broadcast_shape);
|
||||
}
|
||||
|
||||
// Returns the dynamic result of checking the given value is effectively a
|
||||
// scalar shape (i.e. the number of elements is 1).
|
||||
Value GreaterRankIsN(OpBuilder &builder, Location loc, Value actual_rank,
|
||||
int targeted_rank) const {
|
||||
return builder.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, actual_rank,
|
||||
builder.create<ConstantIndexOp>(loc, targeted_rank));
|
||||
}
|
||||
|
||||
// Creates the `scf.if` that guards one rank specialization: its condition is
// `actual_rank == targeted_rank` and its result type is the result type of
// `op`. An else region is requested so callers can chain further cases.
scf::IfOp createIfOpForRankSpecializedBroadcastAndOp(
    OpBuilder &builder, ChloOpTy op, Value actual_rank,
    int targeted_rank) const {
  auto loc = op.getLoc();
  Value rank_matches = GreaterRankIsN(builder, loc, actual_rank, targeted_rank);
  Type if_result_type = op.getResult().getType();
  return builder.create<scf::IfOp>(loc, if_result_type, rank_matches,
                                   /*withElseRegion=*/true);
}
|
||||
|
||||
// Broadcasts the extent tensor `shape` against an all-ones shape of length
// `targeted_rank` and casts the result to the statically sized extent tensor
// type `tensor<targeted_rank x index>`.
Value createBroadcastToKnownRank(OpBuilder &builder, ChloOpTy op, Value shape,
                                 int targeted_rank) const {
  auto loc = op.getLoc();
  auto index_type = builder.getIndexType();
  auto dynamic_extent_type =
      RankedTensorType::get({RankedTensorType::kDynamicSize}, index_type);
  auto static_extent_type =
      RankedTensorType::get({targeted_rank}, index_type);

  // Constant shape [1, 1, ..., 1] with `targeted_rank` entries.
  SmallVector<int64_t, 6> all_ones(targeted_rank, 1);
  Value all_ones_shape = builder.create<shape::ConstShapeOp>(
      loc, static_extent_type,
      mlir::DenseIntElementsAttr::get(static_extent_type, all_ones));

  // The broadcast yields a dynamically sized extent tensor; the cast then
  // pins its length to `targeted_rank`.
  Value broadcasted_shape = builder.create<shape::BroadcastOp>(
      loc, dynamic_extent_type, shape, all_ones_shape, nullptr);
  return builder.create<tensor::CastOp>(loc, static_extent_type,
                                        broadcasted_shape);
}
|
||||
|
||||
// Create the if statement and code for a broadcasting op with a result of a
|
||||
// given rank.
|
||||
void createRankSpecializedBroadcastAndOp(OpBuilder &if_builder, ChloOpTy op,
|
||||
ValueRange operands,
|
||||
ValueRange operand_shapes,
|
||||
int targeted_rank) const {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 2> reshaped_operands;
|
||||
|
||||
auto dynamic_dimensions = llvm::SmallVector<int64_t, 6>(
|
||||
targeted_rank, RankedTensorType::kDynamicSize);
|
||||
|
||||
for (auto it : llvm::zip(operands, operand_shapes)) {
|
||||
Value operand, shape;
|
||||
std::tie(operand, shape) = it;
|
||||
// Handle shape broadcasting and inference.
|
||||
Value extended_operand_casted =
|
||||
createBroadcastToKnownRank(if_builder, op, shape, targeted_rank);
|
||||
|
||||
// 1. Reshape operands to the given rank (with the same number of
|
||||
// elements)
|
||||
// 2. Compute the ranked-broadcasted ChloOp (which will assert that the
|
||||
// ops
|
||||
// can be broadcasted and do the actual broadcasting)
|
||||
// 3. Type erase the output back to unranked
|
||||
auto reshaped_type = RankedTensorType::get(
|
||||
dynamic_dimensions,
|
||||
operand.getType().template dyn_cast<TensorType>().getElementType());
|
||||
Value reshaped_operand = if_builder.create<mhlo::DynamicReshapeOp>(
|
||||
loc, reshaped_type, operand, extended_operand_casted);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
auto result_element_type = op.getResult()
|
||||
.getType()
|
||||
.template dyn_cast<TensorType>()
|
||||
.getElementType();
|
||||
auto result_type =
|
||||
RankedTensorType::get(dynamic_dimensions, result_element_type);
|
||||
Value result = if_builder.create<ChloOpTy>(
|
||||
loc, ArrayRef<Type>{result_type}, reshaped_operands, op->getAttrs());
|
||||
Value reshaped_result = if_builder.create<tensor::CastOp>(
|
||||
loc, UnrankedTensorType::get(result_element_type), result);
|
||||
if_builder.create<scf::YieldOp>(loc, reshaped_result);
|
||||
}
|
||||
|
||||
// Iterates over the desired ranks to be specialized and generates the code
|
||||
// snippet for each case.
|
||||
Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op,
|
||||
ValueRange operands) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
SmallVector<Value> shapes;
|
||||
shapes.reserve(operands.size());
|
||||
auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType());
|
||||
for (Value operand : operands) {
|
||||
Value shape =
|
||||
rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, operand);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
auto broadcast_shape = rewriter.create<shape::BroadcastOp>(
|
||||
loc, extent_tensor_type, shapes, nullptr);
|
||||
SmallVector<Type> result_types(shapes.size(), extent_tensor_type);
|
||||
auto reduced_shapes =
|
||||
rewriter
|
||||
.create<chlo::MinimumBroadcastShapesOp>(loc, result_types, shapes)
|
||||
.results();
|
||||
SmallVector<Value> reshaped_operands;
|
||||
reshaped_operands.reserve(operands.size());
|
||||
for (auto it : llvm::zip(operands, reduced_shapes)) {
|
||||
Value operand;
|
||||
Value reduced_shape;
|
||||
std::tie(operand, reduced_shape) = it;
|
||||
auto reshaped_operand = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, operand.getType(), operand, reduced_shape);
|
||||
reshaped_operands.push_back(reshaped_operand);
|
||||
}
|
||||
|
||||
// Find the largest rank of the operands.
|
||||
Value greater_rank;
|
||||
for (Value shape : reduced_shapes) {
|
||||
Value rank =
|
||||
rewriter.create<shape::RankOp>(loc, rewriter.getIndexType(), shape);
|
||||
if (!greater_rank) {
|
||||
greater_rank = rank;
|
||||
} else {
|
||||
Value greater_rank_compare = rewriter.create<CmpIOp>(
|
||||
loc, CmpIPredicate::sgt, greater_rank, rank);
|
||||
greater_rank = rewriter.create<SelectOp>(loc, greater_rank_compare,
|
||||
greater_rank, rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a list of nested if/else statements to handle rank
|
||||
// specializations from 1 to `kMaxRankSpecialization`.
|
||||
scf::IfOp if_op = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
rewriter, op, greater_rank, 1);
|
||||
OpBuilder if_builder = if_op.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, 1);
|
||||
|
||||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
if_builder = inner_if.getThenBodyBuilder(rewriter.getListener());
|
||||
createRankSpecializedBroadcastAndOp(if_builder, op, reshaped_operands,
|
||||
reduced_shapes, i);
|
||||
else_builder.create<scf::YieldOp>(loc, inner_if.getResult(0));
|
||||
else_builder = inner_if.getElseBodyBuilder(rewriter.getListener());
|
||||
}
|
||||
// Fire an assertion if none of the rank specializations applied (one of
|
||||
// the ranks was greater than `kMaxRankSpecialization`).
|
||||
else_builder.create<AssertOp>(
|
||||
loc,
|
||||
GreaterRankIsN(else_builder, op.getLoc(), greater_rank,
|
||||
kMaxRankSpecialization),
|
||||
"Input for dynamic binary op lowering was of a rank greater than " +
|
||||
std::to_string(kMaxRankSpecialization));
|
||||
// Add the rank 5 specialization to the innermost else block.
|
||||
createRankSpecializedBroadcastAndOp(else_builder, op, reshaped_operands,
|
||||
reduced_shapes, kMaxRankSpecialization);
|
||||
|
||||
// Return the reshaped result of the outermost if statement.
|
||||
auto result = if_op.getResult(0);
|
||||
auto reshaped_result = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
loc, result.getType(), result, broadcast_shape);
|
||||
return reshaped_result;
|
||||
}
|
||||
};
|
||||
|
||||
struct TransformUnrankedHloPass
|
||||
|
|
Loading…
Reference in New Issue