diff options
-rw-r--r-- | tensorflow/c/c_api.cc | 63 | ||||
-rw-r--r-- | tensorflow/core/graph/validate.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/graph/validate.h | 9 |
3 files changed, 67 insertions, 59 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cb0b093ad2..09a03639d6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -390,64 +391,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } -// This traverses the specified nodes in topological order to verify there are -// no cycles. Starting with inputless nodes, it visits nodes whose inputs have -// all been visited, and counts the total number of visited nodes. If there is a -// cycle, nodes in the cycle will never be visited, and the visited count will -// be less than the total node count. -Status ValidateNoCycles(const Graph& g) { - // TODO(nolivia): check this on a subset of the graph instead of all of it. - // A node is ready when all of its inputs have been visited. - std::vector<const Node*> ready; - std::vector<int> pending_count(g.num_node_ids(), 0); - - for (int i = 0; i < g.num_node_ids(); ++i) { - const Node* n = g.FindNodeId(i); - if (n == nullptr) continue; - pending_count[i] = n->in_edges().size(); - if (n->IsMerge()) { - // While-loop cycles are legal cycles so we manually adjust the - // pending_count to make sure that the loop is visited. - for (const Edge* e : n->in_edges()) { - if (!e->IsControlEdge() && e->src()->IsNextIteration()) { - pending_count[i]--; - } - } - } - if (pending_count[i] == 0) { - ready.push_back(n); - } - } - - int processed = 0; - while (!ready.empty()) { - const Node* node = ready.back(); - ready.pop_back(); - ++processed; - - for (const Edge* out : node->out_edges()) { - const int output_id = out->dst()->id(); - pending_count[output_id]--; - if (pending_count[output_id] == 0) { - ready.push_back(out->dst()); - } - } - } - - if (processed < g.num_nodes()) { - std::vector<string> nodes_in_cycle; - for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; - ++i) { - if (pending_count[i] != 0) { - nodes_in_cycle.push_back(g.FindNodeId(i)->name()); - } - } - return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, - " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); - } - return Status::OK(); -} } // namespace } // namespace tensorflow @@ -746,7 +689,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); if (!status->status.ok()) { session->graph->mu.unlock(); return false; diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index bd905651d2..e44eb91d48 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -59,5 +59,59 @@ void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) { RemoveDescriptionsFromOpList(op_list); } +Status ValidateGraphHasNoCycle(const Graph& graph) { + // A node is ready when all of its inputs have been visited. + std::vector<const Node*> ready; + std::vector<int> pending_count(graph.num_node_ids(), 0); + + for (int i = 0; i < graph.num_node_ids(); ++i) { + const Node* n = graph.FindNodeId(i); + if (n == nullptr) continue; + pending_count[i] = n->in_edges().size(); + if (n->IsMerge()) { + // While-loop cycles are legal cycles so we manually adjust the + // pending_count to make sure that the loop is visited. + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + pending_count[i]--; + } + } + } + if (pending_count[i] == 0) { + ready.push_back(n); + } + } + + int processed = 0; + while (!ready.empty()) { + const Node* node = ready.back(); + ready.pop_back(); + ++processed; + + for (const Edge* out : node->out_edges()) { + const int output_id = out->dst()->id(); + pending_count[output_id]--; + if (pending_count[output_id] == 0) { + ready.push_back(out->dst()); + } + } + } + + if (processed < graph.num_nodes()) { + std::vector<string> nodes_in_cycle; + for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; + ++i) { + if (pending_count[i] != 0) { + nodes_in_cycle.push_back(graph.FindNodeId(i)->name()); + } + } + return errors::InvalidArgument( + "Graph is invalid, contains a cycle with ", + graph.num_nodes() - processed, + " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); + } + return Status::OK(); +} + } // namespace graph } // namespace tensorflow diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h index cda93fe1de..08879dca60 100644 --- a/tensorflow/core/graph/validate.h +++ b/tensorflow/core/graph/validate.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -50,6 +51,14 @@ Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, void GetOpListForValidation( OpList* op_list, const OpRegistry& op_registry = *OpRegistry::Global()); +// Validate that the graph has no cycle except for legal while loop cycles. +// This traverses the specified nodes in topological order to verify there are +// no cycles. Starting with inputless nodes, it visits nodes whose inputs have +// all been visited, and counts the total number of visited nodes. If there is a +// cycle, nodes in the cycle will never be visited, and the visited count will +// be less than the total node count. +Status ValidateGraphHasNoCycle(const Graph& graph); + } // namespace graph } // namespace tensorflow |