aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc37
-rw-r--r--tensorflow/core/grappler/utils.cc35
-rw-r--r--tensorflow/core/grappler/utils.h16
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc11
-rw-r--r--tensorflow/core/grappler/utils/topological_sort_test.cc12
6 files changed, 113 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 35b0b7c163..443c0b72ab 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -568,12 +568,16 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
// Forward control dependencies.
for (const auto& input : node->input()) {
- if (IsControlInput(input)) {
+ if (IsControlInput(input) &&
+ std::find(const_node->input().begin(), const_node->input().end(),
+ input) == const_node->input().end()) {
*const_node->add_input() = input;
} else {
NodeDef* input_node = node_map_->GetNode(input);
for (const auto& fanin_of_input : input_node->input()) {
- if (IsControlInput(fanin_of_input)) {
+ if (IsControlInput(fanin_of_input) &&
+ std::find(const_node->input().begin(), const_node->input().end(),
+ fanin_of_input) == const_node->input().end()) {
*const_node->add_input() = fanin_of_input;
}
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2db843f2a4..0f7e7f1d49 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -235,6 +235,43 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
EXPECT_EQ(2, found);
}
+TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
+ Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
+ Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
+ Output c =
+ ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
+ Output i1 = ops::Identity(scope.WithOpName("i1")
+ .WithControlDependencies(p2)
+ .WithControlDependencies(p1),
+ {c});
+ Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});
+
+ GrapplerItem item;
+ item.fetch.push_back("i2");
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding fold;
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i1", "i2"};
+ EXPECT_EQ(output.node_size(), expected_nodes.size());
+ int i = 0;
+ for (const auto& node : output.node()) {
+ EXPECT_EQ(expected_nodes[i], output.node(i).name());
+ i++;
+ if (node.name() == "i1") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("^p1", node.input(0));
+ EXPECT_EQ("^p2", node.input(1));
+ }
+ }
+}
+
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
// Add a DynamicPartition node to the graph
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 9e15744fab..c8830e9b3c 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -95,6 +95,41 @@ void NodeMap::UpdateOutput(const string& node_name,
outputs.insert(nodes_[new_output_name]);
}
+OutputMap::OutputMap(GraphDef* graph) : graph_(graph) {
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
+ auto rslt = nodes_.insert(std::make_pair(node->name(), node));
+ // Check that the graph doesn't contain multiple nodes with the same name.
+ CHECK(rslt.second);
+ for (const auto& input : node->input()) {
+ string input_node = NodeName(input);
+ if (outputs_[input_node].count(node) == 0) {
+ outputs_[input_node].insert(std::make_pair(node, 1));
+ } else {
+ outputs_[input_node][node]++;
+ }
+ }
+ }
+}
+
+NodeDef* OutputMap::GetNode(const string& name) const {
+ string node_name = NodeName(name);
+ auto it = nodes_.find(node_name);
+ if (it == nodes_.end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+const std::unordered_map<NodeDef*, int>& OutputMap::GetOutputs(
+ const string& node_name) const {
+ auto it = outputs_.find(node_name);
+ if (it == outputs_.end()) {
+ return empty_map_;
+ }
+ return it->second;
+}
+
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) {
return true;
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index a9eccd685b..03f49c0ca2 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -52,6 +52,22 @@ class NodeMap {
std::unordered_map<string, std::set<NodeDef*>> outputs_;
};
+// A utility class to lookup a node's outputs and the number of times it
+// presents in each output.
+class OutputMap {
+ public:
+ explicit OutputMap(GraphDef* graph);
+ NodeDef* GetNode(const string& name) const;
+ const std::unordered_map<NodeDef*, int>& GetOutputs(
+ const string& node_name) const;
+
+ private:
+ GraphDef* graph_;
+ std::unordered_map<NodeDef*, int> empty_map_;
+ std::unordered_map<string, NodeDef*> nodes_;
+ std::unordered_map<string, std::unordered_map<NodeDef*, int>> outputs_;
+};
+
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index a5a7f34db0..77d4702d21 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -26,7 +26,7 @@ namespace grappler {
// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
void TopologicalSort(GraphDef* graph) {
- NodeMap node_map(graph);
+ OutputMap output_map(graph);
std::vector<NodeDef*> ready_nodes;
ready_nodes.reserve(graph->node_size());
int front = 0;
@@ -41,7 +41,7 @@ void TopologicalSort(GraphDef* graph) {
if (IsMerge(*node)) {
ready_inputs[node] = 0;
for (const auto& input : node->input()) {
- if (IsNextIteration(*node_map.GetNode(input))) {
+ if (IsNextIteration(*output_map.GetNode(input))) {
ready_inputs[node]++;
}
}
@@ -52,8 +52,9 @@ void TopologicalSort(GraphDef* graph) {
while (front != back) {
auto ready_node = ready_nodes[front];
- for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
- ready_inputs[fanout]++;
+ for (const auto& fanout_pair : output_map.GetOutputs(ready_node->name())) {
+ auto fanout = fanout_pair.first;
+ ready_inputs[fanout] += fanout_pair.second;
if (ready_inputs[fanout] == fanout->input_size()) {
ready_nodes.push_back(fanout);
back++;
@@ -70,6 +71,8 @@ void TopologicalSort(GraphDef* graph) {
new_node->Swap(ready_nodes[i]);
}
graph->mutable_node()->Swap(new_graph.mutable_node());
+ } else {
+ LOG(ERROR) << "The graph couldn't be sorted in topological order.";
}
}
diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc
index 55f66b2734..dc99cb1052 100644
--- a/tensorflow/core/grappler/utils/topological_sort_test.cc
+++ b/tensorflow/core/grappler/utils/topological_sort_test.cc
@@ -89,6 +89,18 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) {
}
}
+TEST_F(TopologicalSortTest, DuplicatedInputs) {
+ GraphDef graph;
+ *graph.add_node() = CreateNode("2", {"1", "1"});
+ *graph.add_node() = CreateNode("1", {});
+
+ TopologicalSort(&graph);
+ std::vector<string> order = {"1", "2"};
+ for (int i = 0; i < order.size(); i++) {
+ EXPECT_EQ(graph.node(i).name(), order[i]);
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow