aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-08-30 13:59:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 14:03:08 -0700
commitac2c94fb30c8e8b44ca40b01082016813eea18d8 (patch)
treee3f54acdd2f8cd5d61435766c9218c8cf6a948a7 /tensorflow/core/grappler/utils.cc
parent8f25fa964751c56a50a6c9c9d420927b98fadf65 (diff)
Handle duplicated inputs in topological sort. And do not add the redundant control dependencies. The would result in malfunction of
topological sort as it previously doesn't handle duplicated inputs. For example, say node A has three repeated input ^B, node A will never get added to queue in topological sort, because the number of ready inputs will always be less than the number of inputs (B is only counted once). node { name: "A" op: "SomeOp" input:"^B" input:"^B" input:"^B" } PiperOrigin-RevId: 167045325
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 9e15744fab..c8830e9b3c 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -95,6 +95,41 @@ void NodeMap::UpdateOutput(const string& node_name,
outputs.insert(nodes_[new_output_name]);
}
+OutputMap::OutputMap(GraphDef* graph) : graph_(graph) {
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
+ auto rslt = nodes_.insert(std::make_pair(node->name(), node));
+ // Check that the graph doesn't contain multiple nodes with the same name.
+ CHECK(rslt.second);
+ for (const auto& input : node->input()) {
+ string input_node = NodeName(input);
+ if (outputs_[input_node].count(node) == 0) {
+ outputs_[input_node].insert(std::make_pair(node, 1));
+ } else {
+ outputs_[input_node][node]++;
+ }
+ }
+ }
+}
+
+NodeDef* OutputMap::GetNode(const string& name) const {
+ string node_name = NodeName(name);
+ auto it = nodes_.find(node_name);
+ if (it == nodes_.end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+const std::unordered_map<NodeDef*, int>& OutputMap::GetOutputs(
+ const string& node_name) const {
+ auto it = outputs_.find(node_name);
+ if (it == outputs_.end()) {
+ return empty_map_;
+ }
+ return it->second;
+}
+
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) {
return true;