diff options
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 35 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils.h | 16 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/topological_sort.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/topological_sort_test.cc | 12 |
6 files changed, 113 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 35b0b7c163..443c0b72ab 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -568,12 +568,16 @@ Status ConstantFolding::FoldNode(NodeDef* node) { // Forward control dependencies. for (const auto& input : node->input()) { - if (IsControlInput(input)) { + if (IsControlInput(input) && + std::find(const_node->input().begin(), const_node->input().end(), + input) == const_node->input().end()) { *const_node->add_input() = input; } else { NodeDef* input_node = node_map_->GetNode(input); for (const auto& fanin_of_input : input_node->input()) { - if (IsControlInput(fanin_of_input)) { + if (IsControlInput(fanin_of_input) && + std::find(const_node->input().begin(), const_node->input().end(), + fanin_of_input) == const_node->input().end()) { *const_node->add_input() = fanin_of_input; } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 2db843f2a4..0f7e7f1d49 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -235,6 +235,43 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) { EXPECT_EQ(2, found); } +TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); + Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); + Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); + Output c = + ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); + Output i1 = ops::Identity(scope.WithOpName("i1") + .WithControlDependencies(p2) + .WithControlDependencies(p1), + {c}); + Output i2 = ops::Identity(scope.WithOpName("i2"), {i1}); + + GrapplerItem item; + item.fetch.push_back("i2"); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i1", "i2"}; + EXPECT_EQ(output.node_size(), expected_nodes.size()); + int i = 0; + for (const auto& node : output.node()) { + EXPECT_EQ(expected_nodes[i], output.node(i).name()); + i++; + if (node.name() == "i1") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("^p1", node.input(0)); + EXPECT_EQ("^p2", node.input(1)); + } + } +} + TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); // Add a DynamicPartition node to the graph diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 9e15744fab..c8830e9b3c 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -95,6 +95,41 @@ void NodeMap::UpdateOutput(const string& node_name, outputs.insert(nodes_[new_output_name]); } +OutputMap::OutputMap(GraphDef* graph) : graph_(graph) { + for (int i = 0; i < graph_->node_size(); i++) { + auto node = graph_->mutable_node(i); + auto rslt = nodes_.insert(std::make_pair(node->name(), node)); + // Check that the graph doesn't contain multiple nodes with the same name. + CHECK(rslt.second); + for (const auto& input : node->input()) { + string input_node = NodeName(input); + if (outputs_[input_node].count(node) == 0) { + outputs_[input_node].insert(std::make_pair(node, 1)); + } else { + outputs_[input_node][node]++; + } + } + } +} + +NodeDef* OutputMap::GetNode(const string& name) const { + string node_name = NodeName(name); + auto it = nodes_.find(node_name); + if (it == nodes_.end()) { + return nullptr; + } + return it->second; +} + +const std::unordered_map<NodeDef*, int>& OutputMap::GetOutputs( + const string& node_name) const { + auto it = outputs_.find(node_name); + if (it == outputs_.end()) { + return empty_map_; + } + return it->second; +} + bool IsSameInput(const string& name1, const string& name2) { if (name1 == name2) { return true; diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index a9eccd685b..03f49c0ca2 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -52,6 +52,22 @@ class NodeMap { std::unordered_map<string, std::set<NodeDef*>> outputs_; }; +// A utility class to lookup a node's outputs and the number of times it +// presents in each output. +class OutputMap { + public: + explicit OutputMap(GraphDef* graph); + NodeDef* GetNode(const string& name) const; + const std::unordered_map<NodeDef*, int>& GetOutputs( + const string& node_name) const; + + private: + GraphDef* graph_; + std::unordered_map<NodeDef*, int> empty_map_; + std::unordered_map<string, NodeDef*> nodes_; + std::unordered_map<string, std::unordered_map<NodeDef*, int>> outputs_; +}; + // True iff 'name' refers to a control inputs, i.e. a node name prefixed with // the ^ character. bool IsControlInput(const string& name); diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index a5a7f34db0..77d4702d21 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -26,7 +26,7 @@ namespace grappler { // Kahn's algorithm is implemented. // For details, see https://en.wikipedia.org/wiki/Topological_sorting void TopologicalSort(GraphDef* graph) { - NodeMap node_map(graph); + OutputMap output_map(graph); std::vector<NodeDef*> ready_nodes; ready_nodes.reserve(graph->node_size()); int front = 0; @@ -41,7 +41,7 @@ void TopologicalSort(GraphDef* graph) { if (IsMerge(*node)) { ready_inputs[node] = 0; for (const auto& input : node->input()) { - if (IsNextIteration(*node_map.GetNode(input))) { + if (IsNextIteration(*output_map.GetNode(input))) { ready_inputs[node]++; } } @@ -52,8 +52,9 @@ void TopologicalSort(GraphDef* graph) { while (front != back) { auto ready_node = ready_nodes[front]; - for (const auto& fanout : node_map.GetOutputs(ready_node->name())) { - ready_inputs[fanout]++; + for (const auto& fanout_pair : output_map.GetOutputs(ready_node->name())) { + auto fanout = fanout_pair.first; + ready_inputs[fanout] += fanout_pair.second; if (ready_inputs[fanout] == fanout->input_size()) { ready_nodes.push_back(fanout); back++; @@ -70,6 +71,8 @@ void TopologicalSort(GraphDef* graph) { new_node->Swap(ready_nodes[i]); } graph->mutable_node()->Swap(new_graph.mutable_node()); + } else { + LOG(ERROR) << "The graph couldn't be sorted in topological order."; } } diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc index 55f66b2734..dc99cb1052 100644 --- a/tensorflow/core/grappler/utils/topological_sort_test.cc +++ b/tensorflow/core/grappler/utils/topological_sort_test.cc @@ -89,6 +89,18 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) { } } +TEST_F(TopologicalSortTest, DuplicatedInputs) { + GraphDef graph; + *graph.add_node() = CreateNode("2", {"1", "1"}); + *graph.add_node() = CreateNode("1", {}); + + TopologicalSort(&graph); + std::vector<string> order = {"1", "2"}; + for (int i = 0; i < order.size(); i++) { + EXPECT_EQ(graph.node(i).name(), order[i]); + } +} + } // namespace } // namespace grappler } // namespace tensorflow |