[MLIR][KernelGen] Merge rank specialization clusters

Merge adjacent rank specialization clusters. Combine their operands, bodies, and
results.
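
For illustration, a sketch mirroring the new test case below (value names made up): two
adjacent clusters, where the second consumes the first cluster's only result,

  %0 = "chlo.rank_specialization_cluster"(%arg0) ({
  ^bb0(%a: tensor<*xf64>):
    %t = "mhlo.tanh"(%a) : (tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%t) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>) -> tensor<*xf64>
  %1 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
  ^bb0(%t_: tensor<*xf64>, %b: tensor<*xf64>):
    %s = "chlo.broadcast_add"(%t_, %b) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%s) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>

merge into a single cluster over the union of their outside operands. Since %0 has no
users besides the second cluster, it is no longer a cluster result and simply flows
through the merged body:

  %1 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
  ^bb0(%a: tensor<*xf64>, %b: tensor<*xf64>):
    %t = "mhlo.tanh"(%a) : (tensor<*xf64>) -> tensor<*xf64>
    %s = "chlo.broadcast_add"(%t, %b) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%s) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>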

PiperOrigin-RevId: 378433222
Authored by A. Unique TensorFlower on 2021-06-09 10:06:47 -07:00; committed by TensorFlow MLIR Team
parent b6d8160611
commit b580722041
2 changed files with 134 additions and 1 deletion


@@ -158,6 +158,112 @@ struct RankSpecializationClusterPattern : public RewritePattern {
  }
};

struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto preceding_op =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!preceding_op) return failure();
    Block *body = op.getBody();
    Block *preceding_body = preceding_op.getBody();
    auto yield_op = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto preceding_yield_op =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            preceding_op.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> new_operands;
    for (Value v : preceding_op.operands()) new_operands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != preceding_op) new_operands.push_back(v);
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second
    // cluster.
    SmallVector<Value, 8> new_unmapped_results;
    for (auto it :
         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
      Value result, inner_result;
      std::tie(result, inner_result) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        new_unmapped_results.push_back(inner_result);
      }
    }
    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(preceding_op);
    auto loc = op.getLoc();
    auto result_types = llvm::to_vector<16>(llvm::map_range(
        new_unmapped_results, [](Value v) { return v.getType(); }));
    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, new_operands);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
    Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
    rewriter.setInsertionPointToStart(new_body);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::enumerate(preceding_body->getArguments()))
      bvm.map(it.value(), new_body->getArgument(it.index()));
    for (Operation &nested_op : preceding_body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    int64_t block_arg_offset = preceding_op->getNumOperands();
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value block_arg, operand;
      std::tie(block_arg, operand) = it;
      if (operand.getDefiningOp() == preceding_op) {
        auto where = llvm::find(preceding_op.results(), operand);
        assert(where.getBase() != nullptr &&
               "expected to find operand among the preceding cluster's results");
        bvm.map(block_arg,
                bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
      } else {
        bvm.map(block_arg, new_body->getArgument(block_arg_offset++));
      }
    }
    for (Operation &nested_op : body->without_terminator()) {
      rewriter.clone(nested_op, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> preceding_op_replacements;
    int64_t i = 0;
    for (Value result : preceding_op.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = new_op->getResult(i++);
      }
      preceding_op_replacements.push_back(replacement);
    }
    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, op_replacements);
    rewriter.replaceOp(preceding_op, preceding_op_replacements);
    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
@@ -678,7 +784,8 @@ struct RankSpecializationToSCFPass
 void PopulateRankSpecializationClusterPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<RankSpecializationClusterPattern>(context);
+  patterns->insert<MergeRankSpecializationClusterOpsPattern,
+                   RankSpecializationClusterPattern>(context);
 }

 void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
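
For context, a minimal sketch of how the populated patterns are typically applied
(the helper name is hypothetical; the driver call is MLIR's standard greedy rewrite
API of this era, not part of this diff):

  // Sketch: run cluster formation and cluster merging to a fixed point.
  void applyClusterPatterns(FuncOp func) {  // hypothetical helper
    MLIRContext *context = func.getContext();
    OwningRewritePatternList patterns(context);
    PopulateRankSpecializationClusterPatterns(context, &patterns);
    // The greedy driver re-applies MergeRankSpecializationClusterOpsPattern
    // until no two adjacent clusters remain.
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }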


@@ -555,3 +555,29 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
 // CHECK-SCF: return %[[RES]]
// -----

// CHECK-LABEL: @merge_clusters
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
    -> tensor<*xf64> {
  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
  // CHECK:   %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
  // CHECK:   %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]]
  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP1]])
  // CHECK: return %[[RES]]
  %0 = "chlo.rank_specialization_cluster"(%arg0) ({
  ^bb0(%arg0_: tensor<*xf64>):
    %1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>) -> (tensor<*xf64>)
  %2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
  ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>):
    %5 = "chlo.broadcast_add"(%3, %4)
        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    "chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> ()
  }) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
  return %2 : tensor<*xf64>
}
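
The diff does not show the test file's RUN lines. Assuming the usual setup for this
file (the tool name and pass flag here are assumptions, not confirmed by the diff),
the test would be driven by something like:

  // RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s

The CHECK-SCF prefixes earlier in the file suggest a second RUN line that additionally
invokes the rank-specialization-to-SCF lowering.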