aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-08-30 13:59:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 14:03:08 -0700
commitac2c94fb30c8e8b44ca40b01082016813eea18d8 (patch)
treee3f54acdd2f8cd5d61435766c9218c8cf6a948a7 /tensorflow/core/grappler
parent8f25fa964751c56a50a6c9c9d420927b98fadf65 (diff)
Handle duplicated inputs in topological sort. And do not add the redundant control dependencies. The would result in malfunction of
topological sort as it previously doesn't handle duplicated inputs. For example, say node A has three repeated input ^B, node A will never get added to queue in topological sort, because the number of ready inputs will always be less than the number of inputs (B is only counted once). node { name: "A" op: "SomeOp" input:"^B" input:"^B" input:"^B" } PiperOrigin-RevId: 167045325
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc37
-rw-r--r--tensorflow/core/grappler/utils.cc35
-rw-r--r--tensorflow/core/grappler/utils.h16
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc11
-rw-r--r--tensorflow/core/grappler/utils/topological_sort_test.cc12
6 files changed, 113 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 35b0b7c163..443c0b72ab 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -568,12 +568,16 @@ Status ConstantFolding::FoldNode(NodeDef* node) {
// Forward control dependencies.
for (const auto& input : node->input()) {
- if (IsControlInput(input)) {
+ if (IsControlInput(input) &&
+ std::find(const_node->input().begin(), const_node->input().end(),
+ input) == const_node->input().end()) {
*const_node->add_input() = input;
} else {
NodeDef* input_node = node_map_->GetNode(input);
for (const auto& fanin_of_input : input_node->input()) {
- if (IsControlInput(fanin_of_input)) {
+ if (IsControlInput(fanin_of_input) &&
+ std::find(const_node->input().begin(), const_node->input().end(),
+ fanin_of_input) == const_node->input().end()) {
*const_node->add_input() = fanin_of_input;
}
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2db843f2a4..0f7e7f1d49 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -235,6 +235,43 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
EXPECT_EQ(2, found);
}
+TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
+ Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
+ Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
+ Output c =
+ ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
+ Output i1 = ops::Identity(scope.WithOpName("i1")
+ .WithControlDependencies(p2)
+ .WithControlDependencies(p1),
+ {c});
+ Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});
+
+ GrapplerItem item;
+ item.fetch.push_back("i2");
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding fold;
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i1", "i2"};
+ EXPECT_EQ(output.node_size(), expected_nodes.size());
+ int i = 0;
+ for (const auto& node : output.node()) {
+ EXPECT_EQ(expected_nodes[i], output.node(i).name());
+ i++;
+ if (node.name() == "i1") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("^p1", node.input(0));
+ EXPECT_EQ("^p2", node.input(1));
+ }
+ }
+}
+
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
// Add a DynamicPartition node to the graph
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 9e15744fab..c8830e9b3c 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -95,6 +95,41 @@ void NodeMap::UpdateOutput(const string& node_name,
outputs.insert(nodes_[new_output_name]);
}
+OutputMap::OutputMap(GraphDef* graph) : graph_(graph) {
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
+ auto rslt = nodes_.insert(std::make_pair(node->name(), node));
+ // Check that the graph doesn't contain multiple nodes with the same name.
+ CHECK(rslt.second);
+ for (const auto& input : node->input()) {
+ string input_node = NodeName(input);
+ if (outputs_[input_node].count(node) == 0) {
+ outputs_[input_node].insert(std::make_pair(node, 1));
+ } else {
+ outputs_[input_node][node]++;
+ }
+ }
+ }
+}
+
+NodeDef* OutputMap::GetNode(const string& name) const {
+ string node_name = NodeName(name);
+ auto it = nodes_.find(node_name);
+ if (it == nodes_.end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+const std::unordered_map<NodeDef*, int>& OutputMap::GetOutputs(
+ const string& node_name) const {
+ auto it = outputs_.find(node_name);
+ if (it == outputs_.end()) {
+ return empty_map_;
+ }
+ return it->second;
+}
+
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) {
return true;
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index a9eccd685b..03f49c0ca2 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -52,6 +52,22 @@ class NodeMap {
std::unordered_map<string, std::set<NodeDef*>> outputs_;
};
+// A utility class to lookup a node's outputs and the number of times it
+// presents in each output.
+class OutputMap {
+ public:
+ explicit OutputMap(GraphDef* graph);
+ NodeDef* GetNode(const string& name) const;
+ const std::unordered_map<NodeDef*, int>& GetOutputs(
+ const string& node_name) const;
+
+ private:
+ GraphDef* graph_;
+ std::unordered_map<NodeDef*, int> empty_map_;
+ std::unordered_map<string, NodeDef*> nodes_;
+ std::unordered_map<string, std::unordered_map<NodeDef*, int>> outputs_;
+};
+
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index a5a7f34db0..77d4702d21 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -26,7 +26,7 @@ namespace grappler {
// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
void TopologicalSort(GraphDef* graph) {
- NodeMap node_map(graph);
+ OutputMap output_map(graph);
std::vector<NodeDef*> ready_nodes;
ready_nodes.reserve(graph->node_size());
int front = 0;
@@ -41,7 +41,7 @@ void TopologicalSort(GraphDef* graph) {
if (IsMerge(*node)) {
ready_inputs[node] = 0;
for (const auto& input : node->input()) {
- if (IsNextIteration(*node_map.GetNode(input))) {
+ if (IsNextIteration(*output_map.GetNode(input))) {
ready_inputs[node]++;
}
}
@@ -52,8 +52,9 @@ void TopologicalSort(GraphDef* graph) {
while (front != back) {
auto ready_node = ready_nodes[front];
- for (const auto& fanout : node_map.GetOutputs(ready_node->name())) {
- ready_inputs[fanout]++;
+ for (const auto& fanout_pair : output_map.GetOutputs(ready_node->name())) {
+ auto fanout = fanout_pair.first;
+ ready_inputs[fanout] += fanout_pair.second;
if (ready_inputs[fanout] == fanout->input_size()) {
ready_nodes.push_back(fanout);
back++;
@@ -70,6 +71,8 @@ void TopologicalSort(GraphDef* graph) {
new_node->Swap(ready_nodes[i]);
}
graph->mutable_node()->Swap(new_graph.mutable_node());
+ } else {
+ LOG(ERROR) << "The graph couldn't be sorted in topological order.";
}
}
diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc
index 55f66b2734..dc99cb1052 100644
--- a/tensorflow/core/grappler/utils/topological_sort_test.cc
+++ b/tensorflow/core/grappler/utils/topological_sort_test.cc
@@ -89,6 +89,18 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) {
}
}
+TEST_F(TopologicalSortTest, DuplicatedInputs) {
+ GraphDef graph;
+ *graph.add_node() = CreateNode("2", {"1", "1"});
+ *graph.add_node() = CreateNode("1", {});
+
+ TopologicalSort(&graph);
+ std::vector<string> order = {"1", "2"};
+ for (int i = 0; i < order.size(); i++) {
+ EXPECT_EQ(graph.node(i).name(), order[i]);
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow