aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/algorithm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/algorithm.cc')
-rw-r--r--tensorflow/core/graph/algorithm.cc37
1 files changed, 21 insertions, 16 deletions
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index 4652fbe406..9b4200e0b4 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -25,7 +25,8 @@ namespace tensorflow {
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
- const NodeComparator& stable_comparator) {
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
// Stack of work to do.
struct Work {
Node* node;
@@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
- gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
auto add_work = [&visited, &stack](Node* out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
if (stable_comparator) {
std::vector<Node*> nodes_sorted;
- for (Node* out : nodes) {
- nodes_sorted.emplace_back(out);
+ for (const Edge* out_edge : n->out_edges()) {
+ if (!edge_filter || edge_filter(*out_edge)) {
+ nodes_sorted.emplace_back(out_edge->dst());
+ }
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (Node* out : nodes_sorted) {
add_work(out);
}
} else {
- for (Node* out : nodes) {
- add_work(out);
+ for (const Edge* out_edge : n->out_edges()) {
+ if (!edge_filter || edge_filter(*out_edge)) {
+ add_work(out_edge->dst());
+ }
}
}
}
@@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
- gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
-
auto add_work = [&visited, &stack](T out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
if (stable_comparator) {
std::vector<T> nodes_sorted;
- for (T in : nodes) {
- nodes_sorted.emplace_back(in);
+ for (const Edge* in_edge : n->in_edges()) {
+ nodes_sorted.emplace_back(in_edge->src());
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (T in : nodes_sorted) {
add_work(in);
}
} else {
- for (T in : nodes) {
- add_work(in);
+ for (const Edge* in_edge : n->in_edges()) {
+ add_work(in_edge->src());
}
}
}
@@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
}
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator) {
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
order->clear();
- DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
+ DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator,
+ edge_filter);
}
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator) {
- GetPostOrder(g, order, stable_comparator);
+ const NodeComparator& stable_comparator,
+ const EdgeFilter& edge_filter) {
+ GetPostOrder(g, order, stable_comparator, edge_filter);
std::reverse(order->begin(), order->end());
}