aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py7
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py11
2 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index ae5e92450a..1aab95152a 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -264,6 +264,13 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
elif request.client_graph_def:
partition_graphs = [request.client_graph_def]
+ if request.tf_error and not os.path.isdir(self._dump_root):
+ # It is possible that the dump root may not exist due to errors that
+ # have occurred prior to graph execution (e.g., invalid device
+ # assignments), in which case we will just raise the exception as the
+ # unwrapped Session does.
+ raise request.tf_error
+
debug_dump = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=partition_graphs)
debug_dump.set_python_graph(self._sess.graph)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 67971be3d3..01578dcb2a 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -313,6 +314,16 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
tf_error = wrapped_sess.observers["tf_errors"][0]
self.assertEqual("y", tf_error.op.name)
+ def testRuntimeErrorBeforeGraphExecutionIsRaised(self):
+ # Use an impossible device name to cause an error before graph execution.
+ with ops.device("/gpu:1337"):
+ w = variables.Variable([1.0] * 10, name="w")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [[]], self.sess, dump_root=self._tmp_dir)
+ with self.assertRaisesRegexp(errors.OpError, r".*[Dd]evice.*1337.*"):
+ wrapped_sess.run(w)
+
def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
# Test command sequence:
# run -f greater_than_twelve; run -f greater_than_twelve; run;