aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-22 17:31:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-22 17:36:31 -0700
commit6dac3f07fe4c69864a39a8a7639ee314d608d38d (patch)
treec8493b70184b26e27833507b6f09826127a2f19b /tensorflow/core/graph
parent357522a0e73a313eff682c1e4b7b485090518c4a (diff)
Move cycle detection helper function from c/c_api to core/graph/validate.
PiperOrigin-RevId: 201766085
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/validate.cc54
-rw-r--r--tensorflow/core/graph/validate.h9
2 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc
index bd905651d2..e44eb91d48 100644
--- a/tensorflow/core/graph/validate.cc
+++ b/tensorflow/core/graph/validate.cc
@@ -59,5 +59,59 @@ void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) {
RemoveDescriptionsFromOpList(op_list);
}
+Status ValidateGraphHasNoCycle(const Graph& graph) {
+ // A node is ready when all of its inputs have been visited.
+ std::vector<const Node*> ready;
+ std::vector<int> pending_count(graph.num_node_ids(), 0);
+
+ for (int i = 0; i < graph.num_node_ids(); ++i) {
+ const Node* n = graph.FindNodeId(i);
+ if (n == nullptr) continue;
+ pending_count[i] = n->in_edges().size();
+ if (n->IsMerge()) {
+ // While-loop cycles are legal cycles so we manually adjust the
+ // pending_count to make sure that the loop is visited.
+ for (const Edge* e : n->in_edges()) {
+ if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
+ pending_count[i]--;
+ }
+ }
+ }
+ if (pending_count[i] == 0) {
+ ready.push_back(n);
+ }
+ }
+
+ int processed = 0;
+ while (!ready.empty()) {
+ const Node* node = ready.back();
+ ready.pop_back();
+ ++processed;
+
+ for (const Edge* out : node->out_edges()) {
+ const int output_id = out->dst()->id();
+ pending_count[output_id]--;
+ if (pending_count[output_id] == 0) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+
+ if (processed < graph.num_nodes()) {
+ std::vector<string> nodes_in_cycle;
+ for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
+ ++i) {
+ if (pending_count[i] != 0) {
+ nodes_in_cycle.push_back(graph.FindNodeId(i)->name());
+ }
+ }
+ return errors::InvalidArgument(
+ "Graph is invalid, contains a cycle with ",
+ graph.num_nodes() - processed,
+ " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
+ }
+ return Status::OK();
+}
+
} // namespace graph
} // namespace tensorflow
diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h
index cda93fe1de..08879dca60 100644
--- a/tensorflow/core/graph/validate.h
+++ b/tensorflow/core/graph/validate.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -50,6 +51,14 @@ Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def,
void GetOpListForValidation(
OpList* op_list, const OpRegistry& op_registry = *OpRegistry::Global());
+// Validate that the graph has no cycle except for legal while loop cycles.
+// This traverses the specified nodes in topological order to verify there are
+// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
+// all been visited, and counts the total number of visited nodes. If there is a
+// cycle, nodes in the cycle will never be visited, and the visited count will
+// be less than the total node count.
+Status ValidateGraphHasNoCycle(const Graph& graph);
+
} // namespace graph
} // namespace tensorflow