diff options
-rw-r--r-- | tensorflow/python/debug/wrappers/local_cli_wrapper.py | 7 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/local_cli_wrapper_test.py | 11 |
2 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index ae5e92450a..1aab95152a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -264,6 +264,13 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): elif request.client_graph_def: partition_graphs = [request.client_graph_def] + if request.tf_error and not os.path.isdir(self._dump_root): + # It is possible that the dump root may not exist due to errors that + # have occurred prior to graph execution (e.g., invalid device + # assignments), in which case we will just raise the exception as the + # unwrapped Session does. + raise request.tf_error + debug_dump = debug_data.DebugDumpDir( self._dump_root, partition_graphs=partition_graphs) debug_dump.set_python_graph(self._sess.graph) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 67971be3d3..01578dcb2a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -27,6 +27,7 @@ from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -313,6 +314,16 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): tf_error = wrapped_sess.observers["tf_errors"][0] self.assertEqual("y", tf_error.op.name) + def testRuntimeErrorBeforeGraphExecutionIsRaised(self): + # Use an impossible device name to cause an error before graph execution. + with ops.device("/gpu:1337"): + w = variables.Variable([1.0] * 10, name="w") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [[]], self.sess, dump_root=self._tmp_dir) + with self.assertRaisesRegexp(errors.OpError, r".*[Dd]evice.*1337.*"): + wrapped_sess.run(w) + def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self): # Test command sequence: # run -f greater_than_twelve; run -f greater_than_twelve; run; |