aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/algorithm.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/algorithm.h')
-rw-r--r--tensorflow/core/graph/algorithm.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index ac4a099013..5bbbc6f6dc 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -28,6 +28,8 @@ namespace tensorflow {
// Comparator for two nodes. This is used in order to get a stable ording.
using NodeComparator = std::function<bool(const Node*, const Node*)>;
+using EdgeFilter = std::function<bool(const Edge&)>;
+
// Compares two node based on their ids.
struct NodeComparatorID {
bool operator()(const Node* n1, const Node* n2) const {
@@ -47,9 +49,11 @@ struct NodeComparatorName {
// If leave is not empty, calls leave(n) after visiting all children of n.
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
+// If edge_filter is set then ignores edges for which edge_filter returns false.
extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Perform a reverse depth-first-search on g starting at the sink node.
// If enter is not empty, calls enter(n) before visiting any parents of n.
@@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
//
+// If edge_filter is set then ignores edges for which edge_filter returns false.
+//
// REQUIRES: order is not NULL.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Stores in *order the reverse post-order numbering of all nodes
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
+//
+// If edge_filter is set then ignores edges for which edge_filter returns false.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
- const NodeComparator& stable_comparator = {});
+ const NodeComparator& stable_comparator = {},
+ const EdgeFilter& edge_filter = {});
// Prune nodes in "g" that are not in some path from the source node
// to any node in 'nodes'. Returns true if changes were made to the graph.