aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-09-06 14:29:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 14:34:14 -0700
commitf8d4b5eae59124c9a4b01a4cc1097e1f0006137a (patch)
treea98c959b859b2c0c02312311544a2eb47ff83379 /tensorflow/c/c_api.cc
parentd528a52baa3562cbf952a5be71ebdc38e1e5c551 (diff)
detecting cycles when users add a control edge to a graph
PiperOrigin-RevId: 167773598
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index c454c94249..334f867e47 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -374,6 +374,65 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
status->status = Reset(opt->options, container_names);
}
+// This traverses the specified nodes in topological order to verify there are
+// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
+// all been visited, and counts the total number of visited nodes. If there is a
+// cycle, nodes in the cycle will never be visited, and the visited count will
+// be less than the total node count.
+Status ValidateNoCycles(const Graph& g) {
+ // TODO(nolivia): check this on a subset of the graph instead of all of it.
+ int total_num_nodes = g.num_node_ids();
+ // A node is ready when all of its inputs have been visited.
+ std::vector<const Node*> ready;
+ std::vector<int> pending_count(total_num_nodes, 0);
+
+ for (int i = 0; i < total_num_nodes; ++i) {
+ const Node* n = g.FindNodeId(i);
+ if (n == nullptr) continue;
+ pending_count[i] = n->in_edges().size();
+ if (n->IsMerge()) {
+ // While-loop cycles are legal cycles so we manually adjust the
+ // pending_count to make sure that the loop is visited.
+ for (const Edge* e : n->in_edges()) {
+ if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
+ pending_count[i]--;
+ }
+ }
+ }
+ if (pending_count[i] == 0) {
+ ready.push_back(n);
+ }
+ }
+
+ int processed = 0;
+ while (!ready.empty()) {
+ const Node* node = ready.back();
+ ready.pop_back();
+ ++processed;
+
+ for (const Edge* out : node->out_edges()) {
+ const int output_id = out->dst()->id();
+ pending_count[output_id]--;
+ if (pending_count[output_id] == 0) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+
+ if (processed < total_num_nodes) {
+ std::vector<string> nodes_in_cycle;
+ for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
+ ++i) {
+ if (pending_count[i] != 0) {
+ nodes_in_cycle.push_back(g.FindNodeId(i)->name());
+ }
+ }
+ return errors::InvalidArgument(
+ "Graph is invalid, contains a cycle with ", total_num_nodes - processed,
+ " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
+ }
+ return Status::OK();
+}
} // namespace
} // namespace tensorflow
@@ -2251,6 +2310,12 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
const Graph& graph = session->graph->graph;
const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
+ status->status = tensorflow::ValidateNoCycles(session->graph->graph);
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
GraphDef graph_def;
*graph_def.mutable_versions() = graph.versions();
// Fill graph_def with nodes with ids in the range