diff options
Diffstat (limited to 'tensorflow/core/graph/algorithm.h')
-rw-r--r-- | tensorflow/core/graph/algorithm.h | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h index ac4a099013..5bbbc6f6dc 100644 --- a/tensorflow/core/graph/algorithm.h +++ b/tensorflow/core/graph/algorithm.h @@ -28,6 +28,8 @@ namespace tensorflow { // Comparator for two nodes. This is used in order to get a stable ording. using NodeComparator = std::function<bool(const Node*, const Node*)>; +using EdgeFilter = std::function<bool(const Edge&)>; + // Compares two node based on their ids. struct NodeComparatorID { bool operator()(const Node* n1, const Node* n2) const { @@ -47,9 +49,11 @@ struct NodeComparatorName { // If leave is not empty, calls leave(n) after visiting all children of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Perform a reverse depth-first-search on g starting at the sink node. // If enter is not empty, calls enter(n) before visiting any parents of n. @@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. // +// If edge_filter is set then ignores edges for which edge_filter returns false. +// // REQUIRES: order is not NULL. void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Stores in *order the reverse post-order numbering of all nodes // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// +// If edge_filter is set then ignores edges for which edge_filter returns false. void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Prune nodes in "g" that are not in some path from the source node // to any node in 'nodes'. Returns true if changes were made to the graph. |