[MLIR][KernelGen] Merge rank specialization clusters
Merge adjacent rank specialization clusters. Combine their operands, bodies,
and results.

PiperOrigin-RevId: 378433222

parent b6d8160611
commit b580722041
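For illustration, here is a hand-written before/after sketch of the rewrite, modeled on the @merge_clusters test added at the end of this change (the SSA value names are illustrative, not taken from the code): the intermediate result that only feeds the second cluster is no longer yielded; it is mapped directly inside the merged body, and the merged cluster takes the union of the remaining operands.

// Before: two adjacent clusters; %0 only feeds the second cluster.
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
^bb0(%a: tensor<*xf64>):
  %t = "mhlo.tanh"(%a) : (tensor<*xf64>) -> tensor<*xf64>
  "chlo.rank_specialization_cluster_yield"(%t) : (tensor<*xf64>) -> ()
}) : (tensor<*xf64>) -> (tensor<*xf64>)
%1 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
^bb0(%lhs: tensor<*xf64>, %rhs: tensor<*xf64>):
  %s = "chlo.broadcast_add"(%lhs, %rhs)
      : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  "chlo.rank_specialization_cluster_yield"(%s) : (tensor<*xf64>) -> ()
}) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)

// After: one merged cluster over both operands; the tanh result is used
// directly by the add inside the merged body instead of being yielded.
%1 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
^bb0(%a: tensor<*xf64>, %rhs: tensor<*xf64>):
  %t = "mhlo.tanh"(%a) : (tensor<*xf64>) -> tensor<*xf64>
  %s = "chlo.broadcast_add"(%t, %rhs)
      : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  "chlo.rank_specialization_cluster_yield"(%s) : (tensor<*xf64>) -> ()
}) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)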
@@ -158,6 +158,112 @@ struct RankSpecializationClusterPattern : public RewritePattern {
  }
};

struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto preceding_op =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!preceding_op) return failure();
    Block *body = op.getBody();
    Block *preceding_body = preceding_op.getBody();
    auto yield_op = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto preceding_yield_op =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            preceding_op.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> new_operands;
    for (Value v : preceding_op.operands()) new_operands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != preceding_op) new_operands.push_back(v);
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second cluster.
    SmallVector<Value, 8> new_unmapped_results;
    for (auto it :
         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
      Value result, inner_result;
      std::tie(result, inner_result) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        new_unmapped_results.push_back(inner_result);
      }
    }
    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(preceding_op);
    auto loc = op.getLoc();
    auto result_types = llvm::to_vector<16>(llvm::map_range(
        new_unmapped_results, [](Value v) { return v.getType(); }));
    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, new_operands);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
    Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
    rewriter.setInsertionPointToStart(new_body);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::enumerate(preceding_body->getArguments()))
      bvm.map(it.value(), new_body->getArgument(it.index()));
    for (Operation &nested_op : preceding_body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    int64_t block_arg_offset = preceding_op->getNumOperands();
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value block_arg, operand;
      std::tie(block_arg, operand) = it;
      if (operand.getDefiningOp() == preceding_op) {
        auto where = llvm::find(preceding_op.results(), operand);
        assert(where.getBase() != nullptr &&
               "expected to find operand among preceding cluster results");
        bvm.map(block_arg,
                bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
      } else {
        bvm.map(block_arg, new_body->getArgument(block_arg_offset++));
      }
    }
    for (Operation &nested_op : body->without_terminator()) {
      rewriter.clone(nested_op, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> preceding_op_replacements;
    int64_t i = 0;
    for (Value result : preceding_op.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = new_op->getResult(i++);
      }
      preceding_op_replacements.push_back(replacement);
    }
    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, op_replacements);
    rewriter.replaceOp(preceding_op, preceding_op_replacements);

    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {

@@ -678,7 +784,8 @@ struct RankSpecializationToSCFPass

void PopulateRankSpecializationClusterPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<MergeRankSpecializationClusterOpsPattern,
                   RankSpecializationClusterPattern>(context);
}

void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
@@ -555,3 +555,29 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]]

// -----

// CHECK-LABEL: @merge_clusters
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
    -> tensor<*xf64> {
  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
  // CHECK:   %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
  // CHECK:   %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]]
  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP1]])
  // CHECK: return %[[RES]]
  %0 = "chlo.rank_specialization_cluster"(%arg0) ({
  ^bb0(%arg0_: tensor<*xf64>):
    %1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>) -> (tensor<*xf64>)
  %2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
  ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>):
    %5 = "chlo.broadcast_add"(%3, %4)
        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
  return %2 : tensor<*xf64>
}