diff options
Diffstat (limited to 'tensorflow/core/graph/algorithm.cc')
-rw-r--r-- | tensorflow/core/graph/algorithm.cc | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index 4652fbe406..9b4200e0b4 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -25,7 +25,8 @@ namespace tensorflow { void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { // Stack of work to do. struct Work { Node* node; @@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->out_nodes(); auto add_work = [&visited, &stack](Node* out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, if (stable_comparator) { std::vector<Node*> nodes_sorted; - for (Node* out : nodes) { - nodes_sorted.emplace_back(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + nodes_sorted.emplace_back(out_edge->dst()); + } } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (Node* out : nodes_sorted) { add_work(out); } } else { - for (Node* out : nodes) { - add_work(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + add_work(out_edge->dst()); + } } } } @@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->in_nodes(); - auto add_work = [&visited, &stack](T out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, if (stable_comparator) { std::vector<T> nodes_sorted; - for (T in : nodes) { - nodes_sorted.emplace_back(in); + for (const Edge* in_edge : n->in_edges()) { + nodes_sorted.emplace_back(in_edge->src()); } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (T in : nodes_sorted) { add_work(in); } } else { - for (T in : nodes) { - add_work(in); + for (const Edge* in_edge : n->in_edges()) { + add_work(in_edge->src()); } } } @@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, } void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { order->clear(); - DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator); + DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator, + edge_filter); } void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { - GetPostOrder(g, order, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + GetPostOrder(g, order, stable_comparator, edge_filter); std::reverse(order->begin(), order->end()); } |