diff options
author | Olivia Nordquist <nolivia@google.com> | 2017-09-06 14:29:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-06 14:34:14 -0700 |
commit | f8d4b5eae59124c9a4b01a4cc1097e1f0006137a (patch) | |
tree | a98c959b859b2c0c02312311544a2eb47ff83379 /tensorflow/c/c_api.cc | |
parent | d528a52baa3562cbf952a5be71ebdc38e1e5c551 (diff) |
detecting cycles when users add a control edge to a graph
PiperOrigin-RevId: 167773598
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index c454c94249..334f867e47 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -374,6 +374,65 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } +// This traverses the specified nodes in topological order to verify there are +// no cycles. Starting with inputless nodes, it visits nodes whose inputs have +// all been visited, and counts the total number of visited nodes. If there is a +// cycle, nodes in the cycle will never be visited, and the visited count will +// be less than the total node count. +Status ValidateNoCycles(const Graph& g) { + // TODO(nolivia): check this on a subset of the graph instead of all of it. + int total_num_nodes = g.num_node_ids(); + // A node is ready when all of its inputs have been visited. + std::vector<const Node*> ready; + std::vector<int> pending_count(total_num_nodes, 0); + + for (int i = 0; i < total_num_nodes; ++i) { + const Node* n = g.FindNodeId(i); + if (n == nullptr) continue; + pending_count[i] = n->in_edges().size(); + if (n->IsMerge()) { + // While-loop cycles are legal cycles so we manually adjust the + // pending_count to make sure that the loop is visited. + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + pending_count[i]--; + } + } + } + if (pending_count[i] == 0) { + ready.push_back(n); + } + } + + int processed = 0; + while (!ready.empty()) { + const Node* node = ready.back(); + ready.pop_back(); + ++processed; + + for (const Edge* out : node->out_edges()) { + const int output_id = out->dst()->id(); + pending_count[output_id]--; + if (pending_count[output_id] == 0) { + ready.push_back(out->dst()); + } + } + } + + if (processed < total_num_nodes) { + std::vector<string> nodes_in_cycle; + for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; + ++i) { + if (pending_count[i] != 0) { + nodes_in_cycle.push_back(g.FindNodeId(i)->name()); + } + } + return errors::InvalidArgument( + "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); + } + return Status::OK(); +} } // namespace } // namespace tensorflow @@ -2251,6 +2310,12 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const Graph& graph = session->graph->graph; const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { + status->status = tensorflow::ValidateNoCycles(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + GraphDef graph_def; *graph_def.mutable_versions() = graph.versions(); // Fill graph_def with nodes with ids in the range |