mlir-hlo/lib/utils/cycle_detector.cc

342 lines
9.8 KiB
C++
Raw Permalink Normal View History

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/utils/cycle_detector.h"
#include <algorithm>
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace {
using NodeSet = llvm::DenseSet<int32_t>;
using OrderedNodeSet = OrderedSet<int32_t>;
template <typename T>
struct VecStruct {
typedef llvm::SmallVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
// rank number assigned by Pearce-Kelly algorithm
int32_t rank;
// Temporary marker used by depth-first-search
bool visited;
// User-supplied data
void* data;
// List of immediate predecessor nodes in graph
OrderedNodeSet in;
// List of immediate successor nodes in graph
OrderedNodeSet out;
};
} // namespace
struct GraphCycles::Rep {
Vec<Node*> nodes;
// Indices for unused entries in nodes
Vec<int32_t> free_nodes;
// Temporary state.
// Results of forward DFS
Vec<int32_t> deltaf;
// Results of backward DFS
Vec<int32_t> deltab;
// All nodes to reprocess
Vec<int32_t> list;
// Rank values to assign to list entries
Vec<int32_t> merged;
// Emulates recursion stack when doing depth first search
Vec<int32_t> stack;
};
GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
rep_->nodes.reserve(num_nodes);
for (int32_t i = 0; i < num_nodes; ++i) {
Node* n = new Node;
n->visited = false;
n->data = nullptr;
n->rank = rep_->nodes.size();
rep_->nodes.push_back(n);
}
}
GraphCycles::~GraphCycles() {
for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
delete rep_->nodes[i];
}
delete rep_;
}
bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
return rep_->nodes[x]->out.Contains(y);
}
void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
rep_->nodes[x]->out.Erase(y);
rep_->nodes[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
static void Reorder(GraphCycles::Rep* r);
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst);
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes[x];
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}
Node* ny = r->nodes[y];
ny->in.Insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
return true;
}
// Current rank assignments are incompatible with the new edge. Recompute.
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return false;
}
BackwardDFS(r, x, ny->rank);
Reorder(r);
return true;
}
// Follows the edges from producer to consumer and searchs if the node having
// rank `n` can reach the node having rank `upper_bound` using a DFS search.
// When doing DFS search, We only consider the pathes that satisfy the ranks
// of the nodes of the path are all smaller than `upper_bound`.
//
// Returns true if such path exists.
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
// Avoid recursion since stack space might be limited.
// We instead keep a stack of nodes to visit.
r->deltaf.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltaf.push_back(n);
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes[w];
if (nw->rank == upper_bound) {
return true;
}
if (!nw->visited && nw->rank < upper_bound) {
r->stack.push_back(w);
}
}
}
return false;
}
// Follows the edges from consumer to producer and visit all the nodes that
// is reachable from node `n` and have rank larger than `lower_bound`.
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
r->deltab.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltab.push_back(n);
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack.push_back(w);
}
}
}
}
// Recomputes rank assignments to make them compatible with the edges (producer
// has smaller rank than its consumer)
static void Reorder(GraphCycles::Rep* r) {
Sort(r->nodes, &r->deltab);
Sort(r->nodes, &r->deltaf);
// Adds contents of delta lists to list (backwards deltas first).
r->list.clear();
MoveToList(r, &r->deltab, &r->list);
MoveToList(r, &r->deltaf, &r->list);
// Produce sorted list of all ranks that will be reassigned.
r->merged.resize(r->deltab.size() + r->deltaf.size());
std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
r->deltaf.end(), r->merged.begin());
// Assign the ranks in order to the collected list.
for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
r->nodes[r->list[i]]->rank = r->merged[i];
}
}
// Sorts nodes in the vector according to their ranks. Small rank first.
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
struct ByRank {
const Vec<Node*>* nodes;
bool operator()(int32_t a, int32_t b) const {
return (*nodes)[a]->rank < (*nodes)[b]->rank;
}
};
ByRank cmp;
cmp.nodes = &nodes;
std::sort(delta->begin(), delta->end(), cmp);
}
// Collects ranks of nodes in vector `src` to vector `dst`
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst) {
for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
int32_t w = (*src)[i];
// Replace src entry with its rank
(*src)[i] = r->nodes[w]->rank;
// Prepare for future DFS calls
r->nodes[w]->visited = false;
dst->push_back(w);
}
}
// Clears bookkeeping fileds used during the last DFS process.
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
r->nodes[nodes[i]]->visited = false;
}
}
bool GraphCycles::IsReachable(int32_t x, int32_t y) {
if (x == y) return true;
Rep* r = rep_;
Node* nx = r->nodes[x];
Node* ny = r->nodes[y];
if (nx->rank >= ny->rank) {
// x cannot reach y since it is after it in the topological ordering
return false;
}
// See if x can reach y using a DFS search that is limited to y's rank
bool reachable = ForwardDFS(r, x, ny->rank);
// Clear any visited markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return reachable;
}
llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
assert(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachable(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return {};
}
if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
// Swap "a" and "b" to minimize copying.
std::swap(a, b);
}
Node* nb = rep_->nodes[b];
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32_t y : out.GetSequence()) {
rep_->nodes[y]->in.Erase(b);
}
for (int32_t y : in.GetSequence()) {
rep_->nodes[y]->out.Erase(b);
}
rep_->free_nodes.push_back(b);
rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
for (int32_t y : out.GetSequence()) {
InsertEdge(a, y);
}
rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
for (int32_t y : in.GetSequence()) {
InsertEdge(y, a);
}
// Note, if the swap happened it might be what originally was called "b".
return a;
}
std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
return rep_->nodes[node]->out.GetSequence();
}
namespace {
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
llvm::DenseSet<int32_t> free_nodes_set;
for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
std::vector<int32_t> all_nodes;
all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
if (!free_nodes_set.count(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes, &all_nodes);
return all_nodes;
}
} // namespace mlir