diff options
author | 2017-08-30 13:59:20 -0700 | |
---|---|---|
committer | 2017-08-30 14:03:08 -0700 | |
commit | ac2c94fb30c8e8b44ca40b01082016813eea18d8 (patch) | |
tree | e3f54acdd2f8cd5d61435766c9218c8cf6a948a7 /tensorflow/core/grappler/utils.cc | |
parent | 8f25fa964751c56a50a6c9c9d420927b98fadf65 (diff) |
Handle duplicated inputs in topological sort. And do not add the redundant control dependencies. The would result in malfunction of
topological sort as it previously doesn't handle duplicated inputs. For example, say node A has three repeated input ^B, node A
will never get added to queue in topological sort, because the number of ready inputs will always be less
than the number of inputs (B is only counted once).
node {
name: "A"
op: "SomeOp"
input:"^B"
input:"^B"
input:"^B"
}
PiperOrigin-RevId: 167045325
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 9e15744fab..c8830e9b3c 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -95,6 +95,41 @@ void NodeMap::UpdateOutput(const string& node_name, outputs.insert(nodes_[new_output_name]); } +OutputMap::OutputMap(GraphDef* graph) : graph_(graph) { + for (int i = 0; i < graph_->node_size(); i++) { + auto node = graph_->mutable_node(i); + auto rslt = nodes_.insert(std::make_pair(node->name(), node)); + // Check that the graph doesn't contain multiple nodes with the same name. + CHECK(rslt.second); + for (const auto& input : node->input()) { + string input_node = NodeName(input); + if (outputs_[input_node].count(node) == 0) { + outputs_[input_node].insert(std::make_pair(node, 1)); + } else { + outputs_[input_node][node]++; + } + } + } +} + +NodeDef* OutputMap::GetNode(const string& name) const { + string node_name = NodeName(name); + auto it = nodes_.find(node_name); + if (it == nodes_.end()) { + return nullptr; + } + return it->second; +} + +const std::unordered_map<NodeDef*, int>& OutputMap::GetOutputs( + const string& node_name) const { + auto it = outputs_.find(node_name); + if (it == outputs_.end()) { + return empty_map_; + } + return it->second; +} + bool IsSameInput(const string& name1, const string& name2) { if (name1 == name2) { return true; |