aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-02-14 10:21:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 10:36:50 -0800
commit2b2f26b08874524a913cfd8cb3d416d6d48ee156 (patch)
tree96283461d5e7f32bc4ca738124a2fbcbaec95b61
parent3e8728dbbc544d60ba731c0c973d2bfa7672257e (diff)
tfdbg bug fix: avoid hanging while watching Enter and NextIteration ops in while loops
Change: 147488620
-rw-r--r--tensorflow/core/debug/debug_graph_utils.cc8
-rw-r--r--tensorflow/python/debug/lib/session_debug_testlib.py47
2 files changed, 52 insertions, 3 deletions
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 4e7a2b4bb8..6f5fe945b7 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -231,10 +231,12 @@ Status DebugNodeInserter::InsertNodes(
// Add control edges from the debug nodes to the destination node
// to ensure that the debug nodes are executed before the destination
- // node.
+ // node. Skip Enter and NextIteration ops to avoid hanging.
for (Node* debug_node : debug_nodes) {
- graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
- Graph::kControlSlot);
+ if (!src_node->IsEnter() && !src_node->IsNextIteration()) {
+ graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
+ Graph::kControlSlot);
+ }
}
}
}
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 3f74d96c84..2042a3bea0 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -410,6 +410,53 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(10, len(dump.watch_key_to_data(watch_keys[0])))
self.assertEqual([], dump.watch_key_to_data("foo"))
+ def testDebugWhileLoopWatchingWholeGraphWorks(self):
+ with session.Session() as sess:
+ loop_body = lambda i: math_ops.add(i, 2)
+ loop_cond = lambda i: math_ops.less(i, 16)
+
+ i = constant_op.constant(10, name="i")
+ loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(run_options,
+ sess.graph,
+ debug_urls=self._debug_urls())
+ run_metadata = config_pb2.RunMetadata()
+ self.assertEqual(
+ 16, sess.run(loop, options=run_options, run_metadata=run_metadata))
+
+ dump = debug_data.DebugDumpDir(
+ self._dump_root, partition_graphs=run_metadata.partition_graphs)
+
+ self.assertEqual(
+ [[10]], dump.get_tensors("while/Enter", 0, "DebugIdentity"))
+ self.assertEqual(
+ [[12], [14], [16]],
+ dump.get_tensors("while/NextIteration", 0, "DebugIdentity"))
+
+ def testDebugCondWatchingWholeGraphWorks(self):
+ with session.Session() as sess:
+ x = variables.Variable(10.0, name="x")
+ y = variables.Variable(20.0, name="y")
+ cond = control_flow_ops.cond(
+ x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
+
+ sess.run(variables.global_variables_initializer())
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(run_options,
+ sess.graph,
+ debug_urls=self._debug_urls())
+ run_metadata = config_pb2.RunMetadata()
+ self.assertEqual(
+ 21, sess.run(cond, options=run_options, run_metadata=run_metadata))
+
+ dump = debug_data.DebugDumpDir(
+ self._dump_root, partition_graphs=run_metadata.partition_graphs)
+ self.assertAllClose(
+ [21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity"))
+
def testFindNodesWithBadTensorValues(self):
with session.Session() as sess:
u_name = "testFindNodesWithBadTensorValues/u"