diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-22 17:31:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-22 17:36:31 -0700 |
commit | 6dac3f07fe4c69864a39a8a7639ee314d608d38d (patch) | |
tree | c8493b70184b26e27833507b6f09826127a2f19b /tensorflow/core/graph | |
parent | 357522a0e73a313eff682c1e4b7b485090518c4a (diff) |
Move cycle detection helper function from c/c_api to core/graph/validate.
PiperOrigin-RevId: 201766085
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/validate.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/graph/validate.h | 9 |
2 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index bd905651d2..e44eb91d48 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -59,5 +59,59 @@ void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) { RemoveDescriptionsFromOpList(op_list); } +Status ValidateGraphHasNoCycle(const Graph& graph) { + // A node is ready when all of its inputs have been visited. + std::vector<const Node*> ready; + std::vector<int> pending_count(graph.num_node_ids(), 0); + + for (int i = 0; i < graph.num_node_ids(); ++i) { + const Node* n = graph.FindNodeId(i); + if (n == nullptr) continue; + pending_count[i] = n->in_edges().size(); + if (n->IsMerge()) { + // While-loop cycles are legal cycles so we manually adjust the + // pending_count to make sure that the loop is visited. + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + pending_count[i]--; + } + } + } + if (pending_count[i] == 0) { + ready.push_back(n); + } + } + + int processed = 0; + while (!ready.empty()) { + const Node* node = ready.back(); + ready.pop_back(); + ++processed; + + for (const Edge* out : node->out_edges()) { + const int output_id = out->dst()->id(); + pending_count[output_id]--; + if (pending_count[output_id] == 0) { + ready.push_back(out->dst()); + } + } + } + + if (processed < graph.num_nodes()) { + std::vector<string> nodes_in_cycle; + for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; + ++i) { + if (pending_count[i] != 0) { + nodes_in_cycle.push_back(graph.FindNodeId(i)->name()); + } + } + return errors::InvalidArgument( + "Graph is invalid, contains a cycle with ", + graph.num_nodes() - processed, + " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); + } + return Status::OK(); +} + } // namespace graph } // namespace tensorflow diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h index cda93fe1de..08879dca60 100644 --- a/tensorflow/core/graph/validate.h +++ b/tensorflow/core/graph/validate.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -50,6 +51,14 @@ Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, void GetOpListForValidation( OpList* op_list, const OpRegistry& op_registry = *OpRegistry::Global()); +// Validate that the graph has no cycle except for legal while loop cycles. +// This traverses the specified nodes in topological order to verify there are +// no cycles. Starting with inputless nodes, it visits nodes whose inputs have +// all been visited, and counts the total number of visited nodes. If there is a +// cycle, nodes in the cycle will never be visited, and the visited count will +// be less than the total node count. +Status ValidateGraphHasNoCycle(const Graph& graph); + } // namespace graph } // namespace tensorflow |