diff options
author | Shanqing Cai <cais@google.com> | 2017-02-14 10:21:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-14 10:36:50 -0800 |
commit | 2b2f26b08874524a913cfd8cb3d416d6d48ee156 (patch) | |
tree | 96283461d5e7f32bc4ca738124a2fbcbaec95b61 | |
parent | 3e8728dbbc544d60ba731c0c973d2bfa7672257e (diff) |
tfdbg bug fix: avoid hanging while watching Enter and NextIteration ops in while loops
Change: 147488620
-rw-r--r-- | tensorflow/core/debug/debug_graph_utils.cc | 8 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_testlib.py | 47 |
2 files changed, 52 insertions, 3 deletions
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 4e7a2b4bb8..6f5fe945b7 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -231,10 +231,12 @@ Status DebugNodeInserter::InsertNodes( // Add control edges from the debug nodes to the destination node // to ensure that the debug nodes are executed before the destination - // node. + // node. Skip Enter and NextIteration ops to avoid hanging. for (Node* debug_node : debug_nodes) { - graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(), - Graph::kControlSlot); + if (!src_node->IsEnter() && !src_node->IsNextIteration()) { + graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(), + Graph::kControlSlot); + } } } } diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 3f74d96c84..2042a3bea0 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -410,6 +410,53 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): self.assertEqual(10, len(dump.watch_key_to_data(watch_keys[0]))) self.assertEqual([], dump.watch_key_to_data("foo")) + def testDebugWhileLoopWatchingWholeGraphWorks(self): + with session.Session() as sess: + loop_body = lambda i: math_ops.add(i, 2) + loop_cond = lambda i: math_ops.less(i, 16) + + i = constant_op.constant(10, name="i") + loop = control_flow_ops.while_loop(loop_cond, loop_body, [i]) + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph(run_options, + sess.graph, + debug_urls=self._debug_urls()) + run_metadata = config_pb2.RunMetadata() + self.assertEqual( + 16, sess.run(loop, options=run_options, run_metadata=run_metadata)) + + dump = debug_data.DebugDumpDir( + self._dump_root, partition_graphs=run_metadata.partition_graphs) + + self.assertEqual( + [[10]], dump.get_tensors("while/Enter", 0, "DebugIdentity")) + self.assertEqual( + [[12], [14], [16]], + dump.get_tensors("while/NextIteration", 0, "DebugIdentity")) + + def testDebugCondWatchingWholeGraphWorks(self): + with session.Session() as sess: + x = variables.Variable(10.0, name="x") + y = variables.Variable(20.0, name="y") + cond = control_flow_ops.cond( + x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1)) + + sess.run(variables.global_variables_initializer()) + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph(run_options, + sess.graph, + debug_urls=self._debug_urls()) + run_metadata = config_pb2.RunMetadata() + self.assertEqual( + 21, sess.run(cond, options=run_options, run_metadata=run_metadata)) + + dump = debug_data.DebugDumpDir( + self._dump_root, partition_graphs=run_metadata.partition_graphs) + self.assertAllClose( + [21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity")) + def testFindNodesWithBadTensorValues(self): with session.Session() as sess: u_name = "testFindNodesWithBadTensorValues/u" |