580 lines
20 KiB
C++
580 lines
20 KiB
C++
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
==============================================================================*/
|
||
|
|
||
|
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||
|
|
||
|
// This pass provides functionality similar to the fusion pass in the XLA
// stack. However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
|
||
|
//
|
||
|
// Similar to XLA, this pass supports fusion patterns having multiple outputs
// if the shapes of all outputs are consistent. Following are some examples.
|
||
|
//
|
||
|
//        kLoop                          kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
|
||
|
//
|
||
|
// To this end, we also add a simple shape constraint analysis phase.
// For the kLoop fusion template, it requires all the outputs of the fused
// pattern to have the same shape. However, we don't know the actual value
// of the shape at compile time in the dynamic shape world.
// Fortunately, we can still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace xla_hlo {
|
||
|
namespace {
|
||
|
|
||
|
using llvm::EquivalenceClasses;
|
||
|
using FusionPattern = std::vector<Operation*>;
|
||
|
using FusionPlan = std::vector<FusionPattern>;
|
||
|
|
||
|
// To support using EquivalenceClasses for Value
|
||
|
class ValueWrapper {
|
||
|
public:
|
||
|
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
|
||
|
|
||
|
Value getValue() const { return value_; }
|
||
|
|
||
|
bool operator==(const ValueWrapper& rhs) const {
|
||
|
return getValue() == rhs.getValue();
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
Value value_;
|
||
|
};
|
||
|
|
||
|
// Orders two wrappers by the opaque pointers of their wrapped values.
//
// `std::less<const void*>` is used instead of the built-in `<` because the
// built-in comparison of pointers that do not belong to the same array yields
// an unspecified result, whereas `std::less` guarantees a strict total order
// over all pointers. The resulting order carries no meaning beyond making the
// wrapper usable as a key of an ordered container.
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
  return std::less<const void*>()(lhs.getValue().getAsOpaquePointer(),
                                  rhs.getValue().getAsOpaquePointer());
}
|
||
|
|
||
|
// Returns true if `op` may take part in a fusion pattern: either it matches
// `m_Constant()`, or it implements `InferFusibilityOpInterface` and reports
// that it can be fused with its operands or with its consumers.
bool IsFusible(Operation* op) {
  // Constants are always accepted.
  if (matchPattern(op, m_Constant())) return true;

  // Ops without the fusibility interface are rejected.
  auto fusibility = dyn_cast<InferFusibilityOpInterface>(op);
  if (!fusibility) return false;

  if (fusibility.isFusibleWithOperand()) return true;
  return fusibility.isFusibleWithConsumer();
}
|
||
|
|
||
|
// Collects the values consumed by `pattern` that are not produced by an op of
// `pattern` itself. Each value is reported once, in the order of its first
// use. `pattern` must not contain the same operation twice (asserted).
SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
  // Membership set of the pattern, also used to detect duplicates.
  DenseSet<Operation*> members;
  for (Operation* member : pattern) {
    bool is_new = members.insert(member).second;
    (void)is_new;
    assert(is_new && "FusionPattern contains duplicate operations");
  }

  SmallVector<Value, 4> inputs;
  DenseSet<Value> seen;
  for (Operation* member : pattern) {
    for (Value operand : member->getOperands()) {
      // Values defined by an op of the pattern are internal to it.
      if (members.count(operand.getDefiningOp())) continue;
      // Keep only the first occurrence of every external value.
      if (!seen.insert(operand).second) continue;
      inputs.push_back(operand);
    }
  }
  return inputs;
}
|
||
|
|
||
|
// Collects the results produced by ops of `pattern` that have at least one
// use owned by an operation outside of `pattern`. Results are reported in
// pattern order. `pattern` must not contain the same operation twice
// (asserted).
SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
  // Membership set of the pattern, also used to detect duplicates.
  DenseSet<Operation*> members;
  for (Operation* member : pattern) {
    bool is_new = members.insert(member).second;
    (void)is_new;
    assert(is_new && "FusionPattern contains duplicate operations");
  }

  SmallVector<Value, 4> outputs;
  for (Operation* member : pattern) {
    for (Value result : member->getResults()) {
      // The result is an output as soon as a single use sits outside.
      for (OpOperand& use : result.getUses()) {
        if (members.count(use.getOwner())) continue;
        outputs.push_back(result);
        break;
      }
    }
  }
  return outputs;
}
|
||
|
|
||
|
// Returns a new pattern made of the ops of `lhs` followed by the ops of
// `rhs`. Neither argument is modified.
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
                                 const FusionPattern& rhs) {
  FusionPattern merged;
  merged.reserve(lhs.size() + rhs.size());
  merged.insert(merged.end(), lhs.begin(), lhs.end());
  merged.insert(merged.end(), rhs.begin(), rhs.end());
  return merged;
}
|
||
|
|
||
|
inline int EffectiveSize(const FusionPattern& pattern) {
|
||
|
return llvm::count_if(
|
||
|
pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
|
||
|
}
|
||
|
|
||
|
// This is an simple shape constraint analysis, which is used to
|
||
|
// guide fusion decision (e.g. we only fuse shape-compatible ops).
|
||
|
//
|
||
|
// Currently, We only consider shape equality propagation based
|
||
|
// on the shape constrain traits of elementwise ops (assuming that
|
||
|
// implicit shape broadcast is forbidden).
|
||
|
class ShapeConstraintAnalysis {
|
||
|
public:
|
||
|
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
|
||
|
PropagateEquality(op_list);
|
||
|
}
|
||
|
|
||
|
// Returns true is `lhs` and `rhs` are supposed to have same shape.
|
||
|
bool HasSameShape(Value lhs, Value rhs) {
|
||
|
return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
// shape equality propagation based on the shape constrains of
|
||
|
// elementwise ops.
|
||
|
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
|
||
|
bool converged = true;
|
||
|
do {
|
||
|
converged = true;
|
||
|
auto update = [&](Value lhs, Value rhs) {
|
||
|
if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
|
||
|
converged = false;
|
||
|
impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
|
||
|
}
|
||
|
};
|
||
|
for (Operation* op : op_list) {
|
||
|
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||
|
if (!op_fusibility) continue;
|
||
|
int numInput = op->getNumOperands();
|
||
|
int numOutput = op->getNumResults();
|
||
|
// shape equality propagation between inputs.
|
||
|
for (int input1 = 0; input1 < numInput; ++input1)
|
||
|
for (int input2 = input1 + 1; input2 < numInput; ++input2)
|
||
|
if (op_fusibility.inferInputsShapeEquality(input1, input2))
|
||
|
update(op->getOperand(input1), op->getOperand(input2));
|
||
|
|
||
|
// shape equality propagation between outputs.
|
||
|
for (int output1 = 0; output1 < numOutput; ++output1)
|
||
|
for (int output2 = output1 + 1; output2 < numOutput; ++output2)
|
||
|
if (op_fusibility.inferOutputsShapeEquality(output1, output2))
|
||
|
update(op->getResult(output1), op->getResult(output2));
|
||
|
|
||
|
// shape equality propagation between input and output.
|
||
|
for (int input = 0; input < numInput; ++input)
|
||
|
for (int output = 0; output < numOutput; ++output)
|
||
|
if (op_fusibility.inferInputOutputShapeEquality(input, output))
|
||
|
update(op->getOperand(input), op->getResult(output));
|
||
|
}
|
||
|
} while (!converged);
|
||
|
}
|
||
|
|
||
|
// a UnionFind set
|
||
|
EquivalenceClasses<ValueWrapper> impl_;
|
||
|
};
|
||
|
|
||
|
// A fusion planner that can propose a fusion plan for a block of ops.
|
||
|
// The fusion plan is consisted of a group of fusion patterns.
|
||
|
//
|
||
|
// Currently all proposed patterns followed xla kLoop/kInput like fusion
|
||
|
// templates while are adapted to the fully dynamic shape world.
|
||
|
//
|
||
|
// kLoop fusion template satifies:
|
||
|
// - all ops in the fusion pattern are element-wise.
|
||
|
// - all the shapes of outputs of fusion pattern are same, and thus can
|
||
|
// fit into a same parallel loop.
|
||
|
//
|
||
|
// kInput fusion template satifies:
|
||
|
// - any op in the fusion pattern is either element-wise or a reduction.
|
||
|
// - if a op is a reduction, its output cannot be consumered by other
|
||
|
// ops in the same fusion pattern.
|
||
|
// - all the effective shapes of outputs of fusion pattern are same.
|
||
|
// - For element-wise op, its effective shape is its output shape.
|
||
|
// - For reduction op, its effective shape is its operand shape.
|
||
|
class FusionPlanner {
|
||
|
public:
|
||
|
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
|
||
|
: op_list_(op_list),
|
||
|
shape_analysis_(op_list),
|
||
|
cycle_detector_(op_list.size()) {
|
||
|
BuildNodeMap();
|
||
|
}
|
||
|
|
||
|
// Returns a fusion plan if success, otherwise none.
|
||
|
llvm::Optional<FusionPlan> Run() {
|
||
|
// Greedily search connected fusible pattern, and ops belonging to
|
||
|
// a same fusion pattern are grouped into a cluster.
|
||
|
RunEdgeContractionLoop();
|
||
|
|
||
|
// After doing edge contraction, each unique cluster having size
|
||
|
// more than one represents a potential fusion pattern.
|
||
|
// We collect all these clusters and construct a fusion plan.
|
||
|
//
|
||
|
// Note that the ops in a fusion pattern are in topological ordering.
|
||
|
FusionPlan plan;
|
||
|
DenseMap<int, int> pattern_ids;
|
||
|
for (Operation* op : op_list_) {
|
||
|
Cluster* cluster = GetClusterForNode(op);
|
||
|
int node_id = cluster->cycles_graph_node_id();
|
||
|
if (!IsFusible(op_list_[node_id]) ||
|
||
|
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
|
||
|
continue;
|
||
|
}
|
||
|
if (!pattern_ids.count(node_id)) {
|
||
|
int pattern_id = pattern_ids.size();
|
||
|
pattern_ids[node_id] = pattern_id;
|
||
|
plan.emplace_back();
|
||
|
}
|
||
|
plan[pattern_ids[node_id]].push_back(op);
|
||
|
}
|
||
|
return plan;
|
||
|
}
|
||
|
|
||
|
// Returns the op_list this planner operates on.
|
||
|
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
|
||
|
|
||
|
private:
|
||
|
// Represent a (partial) fused pattern
|
||
|
class Cluster {
|
||
|
public:
|
||
|
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
|
||
|
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
|
||
|
pattern_.push_back(op_list[node_id]);
|
||
|
}
|
||
|
|
||
|
// Merges `other` into this cluster, and clears `other`.
|
||
|
void Merge(Cluster* other) {
|
||
|
pattern_.insert(pattern_.end(), other->pattern_.begin(),
|
||
|
other->pattern_.end());
|
||
|
other->pattern_.clear();
|
||
|
}
|
||
|
|
||
|
// The number of nodes in this cluster.
|
||
|
int cluster_size() const { return pattern_.size(); }
|
||
|
|
||
|
// The ID of the cluster as represented in `cycle_detector_`.
|
||
|
int cycles_graph_node_id() const { return node_id_; }
|
||
|
|
||
|
// Sets the ID of the cluster as represented in `cycle_detector_`.
|
||
|
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||
|
node_id_ = cycles_graph_node_id;
|
||
|
}
|
||
|
|
||
|
// Currently the fused pattern this cluster holds.
|
||
|
const FusionPattern& fused_pattern() { return pattern_; }
|
||
|
|
||
|
private:
|
||
|
// ID of the representative node of this cluster.
|
||
|
int node_id_;
|
||
|
|
||
|
// the fused pattern this cluster holds.
|
||
|
FusionPattern pattern_;
|
||
|
};
|
||
|
|
||
|
private:
|
||
|
Cluster* MakeCluster(int cycles_graph_node_id) {
|
||
|
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
|
||
|
return cluster_storage_.back().get();
|
||
|
}
|
||
|
|
||
|
void BuildNodeMap() {
|
||
|
int num_nodes = op_list_.size();
|
||
|
for (int node_id = 0; node_id < num_nodes; ++node_id) {
|
||
|
Operation* op = op_list_[node_id];
|
||
|
MakeCluster(node_id);
|
||
|
op_to_node_id_[op] = node_id;
|
||
|
leader_for_node_.insert(node_id);
|
||
|
for (Value operand : op->getOperands()) {
|
||
|
Operation* operand_op = operand.getDefiningOp();
|
||
|
if (operand_op == nullptr) {
|
||
|
// skip block argument
|
||
|
continue;
|
||
|
}
|
||
|
auto iter = op_to_node_id_.find(operand_op);
|
||
|
assert(iter != op_to_node_id_.end());
|
||
|
cycle_detector_.InsertEdge(iter->second, node_id);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Returns the cluster contains this op.
|
||
|
Cluster* GetClusterForNode(Operation* n) {
|
||
|
int id = op_to_node_id_.at(n);
|
||
|
id = leader_for_node_.getLeaderValue(id);
|
||
|
return cluster_storage_[id].get();
|
||
|
}
|
||
|
|
||
|
// Returns the cluster contains the op having `node_id`.
|
||
|
Cluster* GetClusterForCyclesGraphNode(int node_id) {
|
||
|
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
|
||
|
}
|
||
|
|
||
|
// Merges the clusters `cluster_from` and `cluster_to`.
|
||
|
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||
|
int from = cluster_from->cycles_graph_node_id();
|
||
|
int to = cluster_to->cycles_graph_node_id();
|
||
|
|
||
|
auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
|
||
|
if (!optional_merged_node.hasValue()) {
|
||
|
llvm::dbgs() << "Could not contract " << from << " -> " << to
|
||
|
<< " because contracting the edge would create a cycle.";
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
// Merge the clusters.
|
||
|
cluster_from->Merge(cluster_to);
|
||
|
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
|
||
|
|
||
|
// Merge the UnionFind Set.
|
||
|
leader_for_node_.unionSets(from, to);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
template <typename FnTy>
|
||
|
bool ForEachEdgeInPostOrder(FnTy fn) {
|
||
|
bool changed = false;
|
||
|
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
|
||
|
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||
|
// Make a copy of the set of successors because we may modify the graph in
|
||
|
// TryToContractEdge.
|
||
|
std::vector<int32_t> successors_copy =
|
||
|
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
|
||
|
|
||
|
for (int to : successors_copy) {
|
||
|
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
|
||
|
bool contracted_edge = fn(cluster_from, cluster_to);
|
||
|
changed |= contracted_edge;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return changed;
|
||
|
}
|
||
|
|
||
|
// returns the outputs if two cluster were merged
|
||
|
SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
|
||
|
FusionPattern fused_pattern =
|
||
|
MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
|
||
|
return GetOutputsOfFusionPattern(fused_pattern);
|
||
|
}
|
||
|
|
||
|
// This function check if fusing `from` with `to` is valid and if so perform
|
||
|
// the merge. The validity is based on the operations in the clusters and
|
||
|
// the compatibility of the shapes of the outputs of the would-be fused
|
||
|
// clusters.
|
||
|
// Returns true is the merge was performed.
|
||
|
bool TryToContractEdge(Cluster* from, Cluster* to) {
|
||
|
int node_to = to->cycles_graph_node_id();
|
||
|
int node_from = from->cycles_graph_node_id();
|
||
|
|
||
|
// Both node_to and node_from should be fusible
|
||
|
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
auto op_from_fusibility =
|
||
|
dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
|
||
|
if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
|
||
|
// This op cannot be fused with its consumers.
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
auto op_to_fusibility =
|
||
|
dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
|
||
|
if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
|
||
|
// This op cannot be fused with its operands.
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
// Output shapes of a fusion pattern should be compatible as described in
|
||
|
// the document of this class.
|
||
|
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
|
||
|
auto get_workload_shape = [](Value v) {
|
||
|
Operation* op = v.getDefiningOp();
|
||
|
// Block argument
|
||
|
if (!op) return v;
|
||
|
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||
|
// Const value
|
||
|
if (!op_fusibility) return v;
|
||
|
llvm::Optional<Value> workload =
|
||
|
op_fusibility.inferEffectiveWorkloadShape();
|
||
|
return workload.hasValue() ? *workload : v;
|
||
|
};
|
||
|
|
||
|
Value ref = get_workload_shape(results[0]);
|
||
|
if (!llvm::all_of(results, [&](Value result) {
|
||
|
Value val = get_workload_shape(result);
|
||
|
return shape_analysis_.HasSameShape(ref, val);
|
||
|
})) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
return MergeClusters(from, to);
|
||
|
}
|
||
|
|
||
|
// Greedily fuse connected node.
|
||
|
bool RunEdgeContractionLoop() {
|
||
|
using std::placeholders::_1;
|
||
|
using std::placeholders::_2;
|
||
|
return ForEachEdgeInPostOrder(
|
||
|
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
|
||
|
}
|
||
|
|
||
|
const SmallVectorImpl<Operation*>& op_list_;
|
||
|
|
||
|
// Shape equality checker
|
||
|
ShapeConstraintAnalysis shape_analysis_;
|
||
|
|
||
|
// op -> node_id
|
||
|
std::unordered_map<Operation*, int> op_to_node_id_;
|
||
|
|
||
|
// make sure not introduce cycle after fusion
|
||
|
GraphCycles cycle_detector_;
|
||
|
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
|
||
|
|
||
|
// a UnionFind set. Each set represents a (partial) fused pattern
|
||
|
// and has a leader as representation.
|
||
|
EquivalenceClasses<int32_t> leader_for_node_;
|
||
|
};
|
||
|
|
||
|
// Function pass that plans fusion patterns for every block of a function and
// rewrites each pattern into an `xla_hlo::FusionOp`.
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
  // Entry point of the pass. Functions rejected by `IsTargetFunc` are left
  // untouched; otherwise each block is planned and rewritten independently.
  void runOnFunction() override {
    FuncOp func = getFunction();
    if (!IsTargetFunc(func)) return;

    // Reports `message` at the function location and marks the pass failed.
    auto fail = [&](const char* message) {
      emitError(func.getLoc(), message);
      signalPassFailure();
    };

    // Process each block and do fusion within a block.
    for (Block& block : func) {
      SmallVector<Operation*, 4> op_list;
      for (Operation& op : block) op_list.push_back(&op);

      FusionPlanner planner(op_list);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        fail("can't find a fusion plan");
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        fail("apply fusion plan failed");
        return;
      }
    }
  }

  // Returns true if `func` has enough fusion candidates, i.e. more than one
  // op implementing `InferFusibilityOpInterface`.
  bool IsTargetFunc(FuncOp func) {
    int num_candidates = 0;
    func.walk([&](Operation* op) {
      if (dyn_cast<InferFusibilityOpInterface>(op)) ++num_candidates;
      // Stop walking as soon as the threshold is reached.
      if (num_candidates > 1) return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return num_candidates > 1;
  }

  // Rewrites every pattern of `plan` into a fusion op whose region holds the
  // ops of the pattern. Always returns true.
  //
  // NOTE(review): the fusion op is created at the position of the last op of
  // the pattern; this presumably relies on every outside user of an earlier
  // op of the pattern being located after that position - confirm.
  bool ApplyFusionPlan(const FusionPlan& plan) {
    for (const FusionPattern& pattern : plan) {
      Operation* anchor = pattern.back();
      OpBuilder b(anchor);

      // The fusion op carries the fused location of all ops in the pattern.
      SmallVector<Location, 4> locations;
      locations.reserve(pattern.size());
      for (Operation* op : pattern) locations.push_back(op->getLoc());
      Location fused_loc = FusedLoc::get(locations, anchor->getContext());

      // Inputs and outputs are computed before any op is moved.
      SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
      SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
      SmallVector<Type, 4> output_types;
      output_types.reserve(outputs.size());
      for (Value output : outputs) output_types.push_back(output.getType());

      FusionOp fusion =
          b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);

      // Move the ops of the pattern into a fresh block of the fusion region
      // and terminate it with a return of the pattern outputs.
      Region& region = fusion.fused_computation();
      region.push_back(new Block);
      Block& block = region.front();
      for (Operation* op : pattern) op->moveBefore(&block, block.end());
      b.setInsertionPoint(&block, block.end());
      b.create<xla_hlo::ReturnOp>(fused_loc, outputs);

      // Redirect the uses whose owner is not directly in the fusion block to
      // the matching result of the fusion op.
      for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
        Value output = std::get<0>(output_and_result);
        Value fusion_result = std::get<1>(output_and_result);
        for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
          if (use.getOwner()->getBlock() == &block) continue;
          use.set(fusion_result);
        }
      }
    }
    return true;
  }
};
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
// Creates a new instance of the fusion pass.
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<XlaHloFusion>();
  return pass;
}
|
||
|
|
||
|
// Registers the fusion pass under the `xla-hlo-fusion` pass argument.
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
    "xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
|
||
|
|
||
|
} // namespace xla_hlo
|
||
|
} // namespace mlir
|