diff options
Diffstat (limited to 'tensorflow/core/grappler/utils/scc.cc')
-rw-r--r-- | tensorflow/core/grappler/utils/scc.cc | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/core/grappler/utils/scc.cc b/tensorflow/core/grappler/utils/scc.cc index f2a6507d94..d033e9c522 100644 --- a/tensorflow/core/grappler/utils/scc.cc +++ b/tensorflow/core/grappler/utils/scc.cc @@ -142,9 +142,13 @@ void StronglyConnectedComponents( // Create a list of top-level parents (add them to object queue) // Also create a mapping from nodes to their children. + // Inputs might not be present if called on a subgraph. for (const NodeDef& node : graph.node()) { for (const string& input : node.input()) { - name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]); + auto it = name_to_data.find(NodeName(input)); + if (it != name_to_data.end()) { + it->second->children.push_back(node_to_data[&node]); + } } } @@ -202,10 +206,12 @@ int IdentifyLoops(const GraphDef& graph, const std::vector<const NodeDef*>& component_nodes = component.second; std::vector<std::pair<NodeDef*, string>> next_iter_nodes; GraphDef subgraph; + std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping; for (const auto& component_node : component_nodes) { NodeDef* node = subgraph.add_node(); *node = *component_node; + subgraph_mapping[node] = component_node; if (IsNextIteration(*node)) { CHECK_EQ(1, node->input_size()); next_iter_nodes.emplace_back(node, node->input(0)); @@ -227,13 +233,13 @@ int IdentifyLoops(const GraphDef& graph, int num_components = 0; std::unordered_map<const NodeDef*, int> components; StronglyConnectedComponents(subgraph, &components, &num_components); - CHECK_EQ(1, num_components); + CHECK_GE(num_components, 1); for (const auto it : components) { int id = it.second; if (id < 0) { continue; } - (*loops)[it.first].push_back(loop_id); + (*loops)[subgraph_mapping[it.first]].push_back(loop_id); } ++loop_id; } |