aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-01-29 10:22:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 10:33:05 -0800
commit1f26c65254268730b7409f517d1ed1b554d01e50 (patch)
treeac62a070800ecd3c6fa057ada2ba463cbf97dc80 /tensorflow/python/debug
parent0905a7ed035e66f3abdb123ab53cb0c640e60f0b (diff)
tfdbg: let session wrappers handle empty fetches correctly
Fixes: #15882 PiperOrigin-RevId: 183685645
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py5
-rw-r--r--tensorflow/python/debug/wrappers/framework.py9
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py14
3 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index acea9433e2..254201c393 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -389,6 +389,11 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
r"mode\."):
sess.invoke_node_stepper(node_stepper)
+ def testDumpingWrapperWithEmptyFetchWorks(self):
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root, log_usage=False)
+ sess.run([])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 909150eb6a..c530204bbf 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -121,7 +121,9 @@ from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging
from tensorflow.python.training import monitored_session
+from tensorflow.python.util import nest
# Helper function.
@@ -439,7 +441,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
"callable_runner and fetches/feed_dict are mutually exclusive, but "
"are used simultaneously.")
- if self._is_disabled_thread():
+ empty_fetches = not nest.flatten(fetches)
+ if empty_fetches:
+ tf_logging.info(
+ "Due to empty fetches, tfdbg Session wrapper is letting a "
+ "Session.run pass through without any debugging actions.")
+ if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
else:
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 770a496aa9..490812c96d 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -664,6 +664,20 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
[["run"], ["run"]], monitored_sess)
self.assertFalse(wrapped_monitored_sess.should_stop())
+ def testRunsWithEmptyFetchWorks(self):
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]], self.sess, dump_root="")
+
+ run_output = wrapped_sess.run([])
+ self.assertEqual([], run_output)
+
+ def testRunsWithEmptyNestedFetchWorks(self):
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]], self.sess, dump_root="")
+
+ run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()})
+ self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output)
+
if __name__ == "__main__":
googletest.main()