// mlir-hlo/lib/Dialect/mhlo/transforms/lhlo_fusion.cc

/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
// This pass has similar functionality to the fusion pass in the XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns with multiple outputs
// as long as the shapes of all outputs are consistent. Some examples follow.
//
// kLoop                              kInput
// +----+  +----+  +----+             +----+    +----+    +----+
// |elem|  |elem|  |elem|             |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+             +-+--+    +----+    +-+--+    |
//   |       |       |                  |                   |       |
//   |       |       |                  |                   |       |
// +-v--+    |     +-v--+            +--v---+            +--v---+   |
// |elem+<---+----<+elem|            |reduce|            |reduce|   |
// +-+--+          +-+--+            +--+---+            +--+---+   |
//   |               |                  |                   |       |
//   |               |                  |                   |       |
//   v               v                  v                   v       v
//
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires all the outputs of the fused pattern
// to have the same shape. However, in the dynamic shape world we don't know
// the actual value of a shape at compile time.
// Fortunately, we can still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
//
// TODO(disc): This file implements fusion on buffer level, re-visit this after
// more shape inference/constraint infras are ready in mhlo level.
// TODO(disc): Not using fusibility interface a.t.m, re-visit this if necessary.
namespace mlir {
namespace lmhlo {
namespace {
// Hyper-parameters that control the behaviour of the fusion planner.
struct FusionOptions {
  // Maximum allowed number of arguments per fused kernel. Here arguments
  // include both read-only buffers and writable buffers.
  //
  // Value-initialized to 0 so that a default-constructed FusionOptions never
  // holds an indeterminate value. A limit of 0 effectively disables fusion
  // for any pattern that has operands or results, so callers are expected
  // to set this explicitly.
  int max_num_arguments_per_kernel = 0;
};
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA kLoop/kInput-like fusion
// templates, adapted to the fully dynamic shape world.
//
// kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the outputs of the fusion pattern have the same shape or the same
//     number of elements, and thus can fit into the same parallel loop.
//
// kInput fusion template satisfies:
//   - every op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective shapes of the outputs of the fusion pattern are the
//     same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
//   - currently our downstream codegen engine only supports 2d -> 1d tensor
//     reduction. TODO(disc): lift this limitation.
//     - 2D row reduction: out[i] = sum({in[i][j] for all j})
//     - 2D column reduction: out[j] = sum({in[i][j] for all i})
class FusionPlanner {
public:
explicit FusionPlanner(const FusionOptions& options, Block* block)
: options_(options), block_(block) {
// Move up metadata-only ops (e.g. dim, shape_of) as far as possible.
MoveUpMetadataOnlyOpsForFusion();
for (Operation& op : *block) {
op_list_.push_back(&op);
}
shape_analysis_.reset(new ShapeConstraintAnalysis(op_list_));
cycle_detector_.reset(new GraphCycles(op_list_.size()));
BuildNodeMap();
}
// Returns a fusion plan if success, otherwise none.
llvm::Optional<FusionPlan> Run() {
// Greedily search connected fusible pattern, and ops belonging to
// a same fusion pattern are grouped into a cluster.
RunEdgeContractionLoop();
// After doing edge contraction, each unique cluster having size
// more than one represents a potential fusion pattern.
// We collect all these clusters and construct a fusion plan.
FusionPlan plan;
DenseSet<Cluster*> seen_clusters;
for (Operation* op : op_list_) {
Cluster* cluster = GetClusterForNode(op);
if (!seen_clusters.insert(cluster).second) continue;
FusionPattern& fusion_pattern = cluster->fused_pattern();
// Make sure the ops in a fusion pattern are in topological ordering.
fusion_pattern.sortFusionOpListBy(op_to_node_id_);
if (!fusion_pattern.isFusible() || fusion_pattern.effectiveSize() <= 1) {
continue;
}
plan.emplace_back(fusion_pattern);
}
// Re-order ops inside the blocks to make sure all producers are placed
// before its consumers after fusion.
ReorderOperationsInsideBlock();
return plan;
}
// Returns the op_list this planner operates on.
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
private:
// Represent a (partial) fused pattern
class Cluster {
public:
Cluster(int node_id, FusionPlanner* planner)
: node_id_(node_id), pattern_(planner->op_list()[node_id]) {}
// Merges `other` into this cluster, and clears `other`.
void Merge(Cluster* other) {
pattern_.mergeInplace(other->fused_pattern());
}
// The number of nodes in this cluster.
int cluster_size() { return pattern_.size(); }
// The ID of the cluster as represented in `cycle_detector_`.
int cycles_graph_node_id() const { return node_id_; }
// Sets the ID of the cluster as represented in `cycle_detector_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
node_id_ = cycles_graph_node_id;
}
// Currently the fused pattern this cluster holds.
FusionPattern& fused_pattern() { return pattern_; }
private:
// ID of the representative node of this cluster.
int node_id_;
// the fused pattern this cluster holds.
FusionPattern pattern_;
};
private:
// Returns a new cluster with specified `cycles_graph_node_id`
Cluster* MakeCluster(int cycles_graph_node_id) {
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
return cluster_storage_.back().get();
}
// Metadata ops (e.g. shapeOf, dimOp) don't change data thus we move forward
// them as far as possible inside the same block to enable more fusion
// opportunities.
void MoveUpMetadataOnlyOpsForFusion() {
SmallVector<Operation*, 4> ops;
for (Operation& op : *block_) {
ops.push_back(&op);
}
auto inBlock = [&](Operation* op, Block* block) {
return op && op->getBlock() == block;
};
for (Operation* op : ops) {
Block* block = op->getBlock();
if (isa<shape::ShapeOfOp>(op)) {
Operation* definingOp = op->getOperand(0).getDefiningOp();
if (!inBlock(definingOp, block)) {
op->moveBefore(block, block->begin());
} else {
op->moveAfter(definingOp);
}
} else if (isa<memref::DimOp>(op)) {
Operation* firstOperandOp = op->getOperand(0).getDefiningOp();
Operation* secondOperandOp = op->getOperand(1).getDefiningOp();
if (!inBlock(firstOperandOp, block) &&
!inBlock(secondOperandOp, block)) {
op->moveBefore(block, block->begin());
} else if (!inBlock(firstOperandOp, block)) {
op->moveAfter(secondOperandOp);
} else if (!inBlock(secondOperandOp, block)) {
op->moveAfter(firstOperandOp);
} else if (firstOperandOp->isBeforeInBlock(secondOperandOp)) {
op->moveAfter(secondOperandOp);
} else {
op->moveAfter(firstOperandOp);
}
}
}
}
// Returns all the values touched by this op or its nested ops.
SmallVector<Value, 4> GetAllPossibleUsedValues(Operation* op) {
SmallVector<Value, 4> values;
op->walk([&](Operation* nest_op) {
for (Value v : nest_op->getOperands()) {
values.push_back(v);
}
});
return values;
}
// Builds the initial dependency graph.
void BuildNodeMap() {
int num_nodes = op_list_.size();
for (int node_id = 0; node_id < num_nodes; ++node_id) {
Operation* op = op_list_[node_id];
MakeCluster(node_id);
op_to_node_id_[op] = node_id;
leader_for_node_.insert(node_id);
for (Value operand : GetAllPossibleUsedValues(op)) {
Operation* operand_op = FindLastWriter(operand);
// Only consider the operand_op inside the target block.
auto iter = op_to_node_id_.find(operand_op);
if (iter == op_to_node_id_.end()) {
continue;
}
// Add an edge to connect the last writer and the current consumer.
cycle_detector_->InsertEdge(iter->second, node_id);
}
// For some ops (e.g. lmhlo ops), some operands are the output memrefs
// Thus these operands are supposed to be updated.
// Suppose that a op (or its nested ops) can only write the buffers
// explicit passed in as operands of this op.
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value v : op->getOperands().drop_front(num_input_operand)) {
auto it = last_writer_.try_emplace(v, op);
(void)it;
// Currently, a buffer is only supposed to be written once (as the
// output operand of one lmhlo op).
assert(it.second);
}
}
}
// Returns the cluster contains this op.
Cluster* GetClusterForNode(Operation* n) {
int id = op_to_node_id_[n];
id = leader_for_node_.getLeaderValue(id);
return cluster_storage_[id].get();
}
// Returns the cluster contains the op having `node_id`.
Cluster* GetClusterForCyclesGraphNode(int node_id) {
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
}
// Merges the clusters `cluster_from` and `cluster_to`.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
auto optional_merged_node = cycle_detector_->ContractEdge(from, to);
if (!optional_merged_node.hasValue()) {
llvm::dbgs() << "Could not contract " << from << " -> " << to
<< " because contracting the edge would create a cycle.";
return false;
}
// Merge the clusters.
cluster_from->Merge(cluster_to);
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
// Merge the UnionFind Set.
leader_for_node_.unionSets(from, to);
return true;
}
using FnTy = llvm::function_ref<bool(Cluster*, Cluster*)>;
bool ForEachEdgeInPostOrder(FnTy fn, bool enable_cross_fusion = false) {
bool changed = false;
for (int32_t node : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32_t> successors_copy =
cycle_detector_->SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
bool contracted_edge = fn(cluster_from, cluster_to);
changed |= contracted_edge;
}
}
if (!enable_cross_fusion) return changed;
// To enable even more fusion opportunities (e.g. horizontal fusion)
for (int32_t lhs : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_lhs = GetClusterForCyclesGraphNode(lhs);
if (!cluster_lhs) {
continue;
}
for (int32_t rhs : cycle_detector_->AllNodesInPostOrder()) {
Cluster* cluster_rhs = GetClusterForCyclesGraphNode(rhs);
if (!cluster_rhs || cluster_lhs == cluster_rhs) {
continue;
}
bool contracted_edge = fn(cluster_lhs, cluster_rhs);
changed |= contracted_edge;
}
}
return changed;
}
// This function check if fusing `from` with `to` is valid and if so perform
// the merge. The validity is based on the operations in the clusters and
// the compatibility of the shapes of the outputs of the would-be fused
// clusters.
// Returns true is the merge was performed.
bool TryToContractEdge(Cluster* from, Cluster* to) {
// Try merge and check if valid.
if (!from->fused_pattern().isMergeable(to->fused_pattern())) return false;
FusionPattern fused_pattern =
from->fused_pattern().merge(to->fused_pattern());
auto& op_list = fused_pattern.getOpList();
auto& operands = fused_pattern.getOperands();
auto& results = fused_pattern.getResults();
if (results.size() + operands.size() >
options_.max_num_arguments_per_kernel) {
// some backend devices (e.g. GPU) do not support a kernel with
// too many arguments.
return false;
}
// We currently do not support a constant op as final output of a fusion
// pattern.
// TODO(disc): copy small const in case necessary.
for (Value result : results) {
Operation* result_op = FindLastWriter(result);
assert(result_op);
if (isa<lmhlo::ConstOp>(result_op)) {
return false;
}
}
// ReduceOp can not have consumer within the fusion pattern.
for (Operation* op : op_list) {
if (!isa<lmhlo::ReduceOp>(op)) continue;
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
for (Value v : op->getOperands().drop_front(num_input_operand)) {
for (Operation* user : getValueUsers(v)) {
if (user == op) continue;
if (std::find(op_list.begin(), op_list.end(), user) !=
op_list.end()) {
return false;
}
}
}
}
// All outputs of a fusion pattern should have compatible shape.
// Here `compatible` means:
// - if `to` and `from` are both kInput fusion, all output should have same
// shape.
// - otherwise, all output should have same number of elements.
// No outside users, these ops may be eliminated. We fused it here and let
// latter pass to do such DCE.
if (results.empty()) return true;
bool check_same_shape = (to->fused_pattern().isKInputFusion() &&
from->fused_pattern().isKInputFusion());
auto get_effective_shape = [&](Value v) {
auto result_op = FindLastWriter(v);
assert(result_op);
// effective shape of reduce op is its operand's shape.
return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
};
Value ref_shape = get_effective_shape(results[0]);
if (!llvm::all_of(results, [&](Value result) {
Value shape = get_effective_shape(result);
return check_same_shape
? shape_analysis_->HasSameShape(ref_shape, shape)
: shape_analysis_->HasSameNumElements(ref_shape, shape);
})) {
return false;
}
return MergeClusters(from, to);
}
// Greedily fuse connected node.
bool RunEdgeContractionLoop() {
using std::placeholders::_1;
using std::placeholders::_2;
bool changed = false;
// Run fusion pass repeatedly until nothing to be fused
while (ForEachEdgeInPostOrder(
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2), false)) {
// empty statement by design
}
return changed;
}
// Here `value` is supported to be a pointer to buffer.
// Returns the defining op of `value `if no known op updates the buffer,
// otherwise returns the last op that updates the buffer pointed by the
// `value`.
Operation* FindLastWriter(Value value) {
auto it = last_writer_.find(value);
if (it != last_writer_.end()) {
return it->second;
}
return value.getDefiningOp();
}
// Re-order ops inside the block to make sure that producers are before
// consumers after fusion.
void ReorderOperationsInsideBlock() {
auto reorder_func = [&](Cluster* from, Cluster* to) {
FusionPattern& from_pattern = from->fused_pattern();
FusionPattern& to_pattern = to->fused_pattern();
Operation* last_op_in_from = from_pattern.getOpList().back();
for (Operation* op : llvm::reverse(to_pattern.getOpList())) {
if (!last_op_in_from->isBeforeInBlock(op))
op->moveAfter(last_op_in_from);
}
return false;
};
ForEachEdgeInPostOrder(reorder_func);
}
// hyper-parameters that controls the behaviour of the fusion planner.
FusionOptions options_;
// The block that fusion planner works on.
Block* block_;
// Ops inside the block
SmallVector<Operation*, 4> op_list_;
// Shape equality checker
std::unique_ptr<ShapeConstraintAnalysis> shape_analysis_;
// op -> node_id
DenseMap<Operation*, int> op_to_node_id_;
// make sure not introduce cycle after fusion
std::unique_ptr<GraphCycles> cycle_detector_;
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
// a UnionFind set. Each set represents a (partial) fused pattern
// and has a leader as representation.
EquivalenceClasses<int32_t> leader_for_node_;
// Here `value` is supported to be a pointer to buffer.
// Returns the defining op of `value `if no known op updates the buffer,
// otherwise returns the last op that updates the buffer pointed by the
// `value`.
DenseMap<Value, Operation*> last_writer_;
};
// Function pass that runs the fusion planner on every eligible block of a
// function and materializes each proposed pattern as an lmhlo.fusion op.
struct LhloFusionPass : public LhloFusionPassBase<LhloFusionPass> {
  using LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase;

  // Creates the pass with an explicit per-kernel argument limit; the value
  // is stored in the base-class member `max_num_arguments_per_kernel_`.
  explicit LhloFusionPass(int max_num_arguments_per_kernel)
      : LhloFusionPassBase<LhloFusionPass>::LhloFusionPassBase() {
    this->max_num_arguments_per_kernel_ = max_num_arguments_per_kernel;
  }

  // Plans and applies fusion block by block; the first failure emits an
  // error on the function and aborts the pass.
  void runOnFunction() override {
    FuncOp func = getFunction();

    SmallVector<Block*, 4> candidate_blocks;
    CollectBlocksInsideFunction(func, candidate_blocks);

    FusionOptions fusion_options;
    fusion_options.max_num_arguments_per_kernel = max_num_arguments_per_kernel_;

    // Shared failure path: report on the function location, then fail.
    auto report_failure = [&](const char* message) {
      emitError(func.getLoc(), message);
      signalPassFailure();
    };

    for (Block* block : candidate_blocks) {
      FusionPlanner planner(fusion_options, block);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        report_failure("an error occurs while trying to find fusion candidates");
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        report_failure("apply fusion plan failed");
        return;
      }
    }
  }

  // Wraps the ops of each pattern into a new lmhlo.fusion op created right
  // before the pattern's last op, keeping the ops' relative order inside the
  // fusion body. Always returns true.
  bool ApplyFusionPlan(FusionPlan& plan) {
    for (FusionPattern& pattern : plan) {
      auto& ops = pattern.getOpList();
      Operation* anchor = ops.back();

      // The fusion op carries the fused location of all of its members.
      SmallVector<Location, 4> member_locs;
      member_locs.reserve(ops.size());
      for (Operation* op : ops) member_locs.push_back(op->getLoc());
      Location fused_loc = FusedLoc::get(anchor->getContext(), member_locs);

      OpBuilder builder(anchor);
      FusionOp fusion = builder.create<lmhlo::FusionOp>(fused_loc);
      Block& body = fusion.region().front();

      // Walking backwards while always inserting at the front keeps the
      // original order of the ops.
      for (Operation* op : llvm::reverse(ops))
        op->moveBefore(&body, body.begin());
    }
    return true;
  }

  // Appends to `blocks` every block nested in `op`, skipping the bodies of
  // lmhlo.reduce and lmhlo.fusion ops.
  void CollectBlocksInsideFunction(FuncOp op, SmallVectorImpl<Block*>& blocks) {
    op.walk([&](Block* block) {
      Operation* parent = block->getParentOp();
      // It does not make sense to fuse the region attached to these ops.
      if (isa<lmhlo::ReduceOp, lmhlo::FusionOp>(parent)) return;
      blocks.push_back(block);
    });
  }
};
} // namespace
// Factory for the LHLO fusion pass with the given per-kernel argument limit.
std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass(
    int max_num_arguments_per_kernel) {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LhloFusionPass>(max_num_arguments_per_kernel);
  return pass;
}
} // namespace lmhlo
} // namespace mlir