diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-07-18 15:03:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 15:10:02 -0700 |
commit | 6619dd5fdcad02f087f5758083e2585bdfef9e78 (patch) | |
tree | f250172e24d9eff6d857476afb4a73539ac11c17 /tensorflow/core/graph | |
parent | a186bcdcb0d3b85909560d85167cda55ccbc973b (diff) |
Don't cluster nodes that have inputs with mismatching deadness
TensorFlow allows nodes to have some live inputs and some dead inputs. The
executor does not execute these nodes but instead propagates a dead signal to
all their outputs (i.e. these nodes are treated as fully dead).
This is a problem for auto-clustering because it means auto-clustering can kill
nodes that used to be alive. For instance say before clustering we have a graph
like
digraph {
Alive0 -> P
Alive1 -> Q
Dead -> R
P -> X
Q -> X
Q -> Y
R -> Y
}
and we cluster P, Q, R, X and Y into a single XLA cluster.
Then after clustering both X and Y are dead because the cluster is a single node
as far as the executor is concerned and said node won't get scheduled if any of
its inputs are dead.
This CL introduces a static analysis pass that our auto-clustering code can use
to ensure nodes that have inputs with mismatching deadness (like "Y" in the
example graph) are not included in XLA clusters.
PiperOrigin-RevId: 205143316
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/algorithm.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/graph/algorithm.h | 16 | ||||
-rw-r--r-- | tensorflow/core/graph/algorithm_test.cc | 52 |
3 files changed, 86 insertions, 19 deletions
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index 4652fbe406..9b4200e0b4 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -25,7 +25,8 @@ namespace tensorflow { void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { // Stack of work to do. struct Work { Node* node; @@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->out_nodes(); auto add_work = [&visited, &stack](Node* out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter, if (stable_comparator) { std::vector<Node*> nodes_sorted; - for (Node* out : nodes) { - nodes_sorted.emplace_back(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + nodes_sorted.emplace_back(out_edge->dst()); + } } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (Node* out : nodes_sorted) { add_work(out); } } else { - for (Node* out : nodes) { - add_work(out); + for (const Edge* out_edge : n->out_edges()) { + if (!edge_filter || edge_filter(*out_edge)) { + add_work(out_edge->dst()); + } } } } @@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, // Arrange to call leave(n) when all done with descendants. if (leave) stack.push_back(Work{n, true}); - gtl::iterator_range<NeighborIter> nodes = n->in_nodes(); - auto add_work = [&visited, &stack](T out) { if (!visited[out->id()]) { // Note; we must not mark as visited until we actually process it. @@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, if (stable_comparator) { std::vector<T> nodes_sorted; - for (T in : nodes) { - nodes_sorted.emplace_back(in); + for (const Edge* in_edge : n->in_edges()) { + nodes_sorted.emplace_back(in_edge->src()); } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (T in : nodes_sorted) { add_work(in); } } else { - for (T in : nodes) { - add_work(in); + for (const Edge* in_edge : n->in_edges()) { + add_work(in_edge->src()); } } } @@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, } void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { order->clear(); - DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator); + DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator, + edge_filter); } void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator) { - GetPostOrder(g, order, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + GetPostOrder(g, order, stable_comparator, edge_filter); std::reverse(order->begin(), order->end()); } diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h index ac4a099013..5bbbc6f6dc 100644 --- a/tensorflow/core/graph/algorithm.h +++ b/tensorflow/core/graph/algorithm.h @@ -28,6 +28,8 @@ namespace tensorflow { // Comparator for two nodes. This is used in order to get a stable ording. using NodeComparator = std::function<bool(const Node*, const Node*)>; +using EdgeFilter = std::function<bool(const Edge&)>; + // Compares two node based on their ids. struct NodeComparatorID { bool operator()(const Node* n1, const Node* n2) const { @@ -47,9 +49,11 @@ struct NodeComparatorName { // If leave is not empty, calls leave(n) after visiting all children of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void DFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Perform a reverse depth-first-search on g starting at the sink node. // If enter is not empty, calls enter(n) before visiting any parents of n. @@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. // +// If edge_filter is set then ignores edges for which edge_filter returns false. +// // REQUIRES: order is not NULL. void GetPostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Stores in *order the reverse post-order numbering of all nodes // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// +// If edge_filter is set then ignores edges for which edge_filter returns false. void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Prune nodes in "g" that are not in some path from the source node // to any node in 'nodes'. Returns true if changes were made to the graph. diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index f67d5a2fd2..60a3e66aa1 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -36,6 +36,11 @@ namespace { REGISTER_OP("TestParams").Output("o: float"); REGISTER_OP("TestInput").Output("a: float").Output("b: float"); REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestUnary").Input("a: float").Output("o: float"); +REGISTER_OP("TestBinary") + .Input("a: float") + .Input("b: float") + .Output("o: float"); // Compares that the order of nodes in 'inputs' respects the // pair orders described in 'ordered_pairs'. @@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) { EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error)); } } + +TEST(AlgorithmTest, PostOrderWithEdgeFilter) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + string error; + Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0")); + Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1")); + Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2")); + Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1); + + std::vector<Node*> post_order; + auto edge_filter = [&](const Edge& e) { + return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id()); + }; + + std::vector<Node*> expected_post_order = { + g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()), + g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()}; + + std::vector<Node*> expected_reverse_post_order = expected_post_order; + std::reverse(expected_reverse_post_order.begin(), + expected_reverse_post_order.end()); + + GetPostOrder(g, &post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_post_order.size(), post_order.size()); + for (int i = 0; i < post_order.size(); i++) { + CHECK_EQ(post_order[i], expected_post_order[i]) + << post_order[i]->name() << " vs. " << expected_post_order[i]->name(); + } + + std::vector<Node*> reverse_post_order; + GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{}, + /*edge_filter=*/edge_filter); + + ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size()); + for (int i = 0; i < reverse_post_order.size(); i++) { + CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i]) + << reverse_post_order[i]->name() << " vs. " + << expected_reverse_post_order[i]->name(); + } +} } // namespace } // namespace tensorflow |