diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index b00cde4..91aa421 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -158,6 +158,112 @@ struct RankSpecializationClusterPattern : public RewritePattern {
   }
 };
 
+struct MergeRankSpecializationClusterOpsPattern
+    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
+  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
+                                PatternRewriter &rewriter) const override {
+    auto preceding_op =
+        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
+            op->getPrevNode());
+    if (!preceding_op) return failure();
+    Block *body = op.getBody();
+    Block *preceding_body = preceding_op.getBody();
+    auto yield_op = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
+        op.getBody()->getTerminator());
+    auto preceding_yield_op =
+        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
+            preceding_op.getBody()->getTerminator());
+
+    // Merge cluster operands. Consider only those operands of the second
+    // cluster that do not originate in the preceding cluster.
+    SmallVector<Value, 16> new_operands;
+    for (Value v : preceding_op.operands()) new_operands.push_back(v);
+    for (Value v : op.operands()) {
+      if (v.getDefiningOp() != preceding_op) new_operands.push_back(v);
+    }
+
+    // Merge cluster results. Consider only those results of the preceding
+    // cluster that are not exclusively used as operands to the second cluster.
+    SmallVector<Value, 16> new_unmapped_results;
+    for (auto it :
+         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
+      Value result, inner_result;
+      std::tie(result, inner_result) = it;
+      if (!llvm::all_of(result.getUsers(),
+                        [&](Operation *user) { return user == op; })) {
+        new_unmapped_results.push_back(inner_result);
+      }
+    }
+    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);
+
+    // Create merged cluster op.
+    rewriter.setInsertionPoint(preceding_op);
+    auto loc = op.getLoc();
+    auto result_types = llvm::to_vector<16>(llvm::map_range(
+        new_unmapped_results, [](Value v) { return v.getType(); }));
+    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
+        loc, result_types, new_operands);
+    auto operand_types = llvm::to_vector<16>(
+        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
+    Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
+    rewriter.setInsertionPointToStart(new_body);
+
+    // Map operands and copy operations of the preceding cluster into the new
+    // body.
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::enumerate(preceding_body->getArguments()))
+      bvm.map(it.value(), new_body->getArgument(it.index()));
+    for (Operation &nested_op : preceding_body->without_terminator())
+      rewriter.clone(nested_op, bvm);
+
+    // Map operands and copy operations of the second cluster. If they result
+    // from the preceding cluster, we can simply map the corresponding value
+    // internally.
+    int64_t block_arg_offset = preceding_op->getNumOperands();
+    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
+      Value block_arg, operand;
+      std::tie(block_arg, operand) = it;
+      if (operand.getDefiningOp() == preceding_op) {
+        auto where = llvm::find(preceding_op.results(), operand);
+        assert(where.getBase() != nullptr && "expected to find operand");
+        bvm.map(block_arg,
+                bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
+      } else {
+        bvm.map(block_arg, new_body->getArgument(block_arg_offset++));
+      }
+    }
+    for (Operation &nested_op : body->without_terminator()) {
+      rewriter.clone(nested_op, bvm);
+    }
+
+    // Yield inner results.
+    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
+        loc,
+        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
+          return bvm.lookupOrDefault(v);
+        })));
+
+    // Replace the two cluster ops with the new corresponding results.
+    SmallVector<Value, 16> preceding_op_replacements;
+    int64_t i = 0;
+    for (Value result : preceding_op.results()) {
+      Value replacement = nullptr;
+      if (!llvm::all_of(result.getUsers(),
+                        [&](Operation *user) { return user == op; })) {
+        replacement = new_op->getResult(i++);
+      }
+      preceding_op_replacements.push_back(replacement);
+    }
+    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
+    rewriter.replaceOp(op, op_replacements);
+    rewriter.replaceOp(preceding_op, preceding_op_replacements);
+
+    return success();
+  }
+};
+
 struct RankSpecializationClusterPass
     : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -678,7 +784,8 @@ struct RankSpecializationToSCFPass
 
 void PopulateRankSpecializationClusterPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<RankSpecializationClusterPattern>(context);
+  patterns->insert<MergeRankSpecializationClusterOpsPattern,
+                   RankSpecializationClusterPattern>(context);
 }
 
 void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index 117fbe7..b216695 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -555,3 +555,29 @@ func @mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES_LHS_SCALAR]], %[[RES_SHAPE]])
 // CHECK-SCF: return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: @merge_clusters
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
+func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
+    -> tensor<*xf64> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
+  // CHECK:   %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
+  // CHECK:   %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]]
+  // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP1]])
+  // CHECK: return %[[RES]]
+  %0 = "chlo.rank_specialization_cluster"(%arg0) ({
+  ^bb0(%arg0_: tensor<*xf64>):
+    %1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
+    "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
+  }) : (tensor<*xf64>) -> (tensor<*xf64>)
+  %2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
+  ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>):
+    %5 = "chlo.broadcast_add"(%3, %4)
+        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+    "chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> ()
+  }) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
+  return %2 : tensor<*xf64>
+}
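
For reference, the merged cluster that the CHECK directives above describe looks roughly as follows. This is a sketch reconstructed from the CHECK lines, not verbatim pass output; the SSA names %res, %tmp0, and %tmp1 are illustrative, and chlo.broadcast_add is spelled in generic form:

  func @merge_clusters(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>)
      -> tensor<*xf64> {
    // Single merged cluster: operands of both original clusters, deduplicated.
    %res = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
    ^bb0(%arg0_: tensor<*xf64>, %arg1_: tensor<*xf64>):
      %tmp0 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
      %tmp1 = "chlo.broadcast_add"(%tmp0, %arg1_)
          : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
      "chlo.rank_specialization_cluster_yield"(%tmp1) : (tensor<*xf64>) -> ()
    }) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    return %res : tensor<*xf64>
  }

Because the first cluster's result %0 is used exclusively by the second cluster, it is dropped from the merged cluster's results and routed internally through the value mapping, which is exactly the exclusive-use check that MergeRankSpecializationClusterOpsPattern performs when it builds new_unmapped_results.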