aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils/scc.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/utils/scc.cc')
-rw-r--r--tensorflow/core/grappler/utils/scc.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/utils/scc.cc b/tensorflow/core/grappler/utils/scc.cc
index f2a6507d94..d033e9c522 100644
--- a/tensorflow/core/grappler/utils/scc.cc
+++ b/tensorflow/core/grappler/utils/scc.cc
@@ -142,9 +142,13 @@ void StronglyConnectedComponents(
// Create a list of top-level parents (add them to object queue)
// Also create a mapping from nodes to their children.
+ // Inputs might not be present if called on a subgraph.
for (const NodeDef& node : graph.node()) {
for (const string& input : node.input()) {
- name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]);
+ auto it = name_to_data.find(NodeName(input));
+ if (it != name_to_data.end()) {
+ it->second->children.push_back(node_to_data[&node]);
+ }
}
}
@@ -202,10 +206,12 @@ int IdentifyLoops(const GraphDef& graph,
const std::vector<const NodeDef*>& component_nodes = component.second;
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
GraphDef subgraph;
+ std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
for (const auto& component_node : component_nodes) {
NodeDef* node = subgraph.add_node();
*node = *component_node;
+ subgraph_mapping[node] = component_node;
if (IsNextIteration(*node)) {
CHECK_EQ(1, node->input_size());
next_iter_nodes.emplace_back(node, node->input(0));
@@ -227,13 +233,13 @@ int IdentifyLoops(const GraphDef& graph,
int num_components = 0;
std::unordered_map<const NodeDef*, int> components;
StronglyConnectedComponents(subgraph, &components, &num_components);
- CHECK_EQ(1, num_components);
+ CHECK_GE(num_components, 1);
for (const auto it : components) {
int id = it.second;
if (id < 0) {
continue;
}
- (*loops)[it.first].push_back(loop_id);
+ (*loops)[subgraph_mapping[it.first]].push_back(loop_id);
}
++loop_id;
}