diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-22 17:31:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-22 17:36:31 -0700 |
commit | 6dac3f07fe4c69864a39a8a7639ee314d608d38d (patch) | |
tree | c8493b70184b26e27833507b6f09826127a2f19b /tensorflow/c/c_api.cc | |
parent | 357522a0e73a313eff682c1e4b7b485090518c4a (diff) |
Move cycle detection helper function from c/c_api to core/graph/validate.
PiperOrigin-RevId: 201766085
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 63 |
1 files changed, 4 insertions, 59 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cb0b093ad2..09a03639d6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -390,64 +391,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } -// This traverses the specified nodes in topological order to verify there are -// no cycles. Starting with inputless nodes, it visits nodes whose inputs have -// all been visited, and counts the total number of visited nodes. If there is a -// cycle, nodes in the cycle will never be visited, and the visited count will -// be less than the total node count. -Status ValidateNoCycles(const Graph& g) { - // TODO(nolivia): check this on a subset of the graph instead of all of it. - // A node is ready when all of its inputs have been visited. - std::vector<const Node*> ready; - std::vector<int> pending_count(g.num_node_ids(), 0); - - for (int i = 0; i < g.num_node_ids(); ++i) { - const Node* n = g.FindNodeId(i); - if (n == nullptr) continue; - pending_count[i] = n->in_edges().size(); - if (n->IsMerge()) { - // While-loop cycles are legal cycles so we manually adjust the - // pending_count to make sure that the loop is visited. - for (const Edge* e : n->in_edges()) { - if (!e->IsControlEdge() && e->src()->IsNextIteration()) { - pending_count[i]--; - } - } - } - if (pending_count[i] == 0) { - ready.push_back(n); - } - } - - int processed = 0; - while (!ready.empty()) { - const Node* node = ready.back(); - ready.pop_back(); - ++processed; - - for (const Edge* out : node->out_edges()) { - const int output_id = out->dst()->id(); - pending_count[output_id]--; - if (pending_count[output_id] == 0) { - ready.push_back(out->dst()); - } - } - } - - if (processed < g.num_nodes()) { - std::vector<string> nodes_in_cycle; - for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; - ++i) { - if (pending_count[i] != 0) { - nodes_in_cycle.push_back(g.FindNodeId(i)->name()); - } - } - return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, - " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); - } - return Status::OK(); -} } // namespace } // namespace tensorflow @@ -746,7 +689,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); if (!status->status.ok()) { session->graph->mu.unlock(); return false; |