diff options
author | 2018-02-13 17:52:50 -0800 | |
---|---|---|
committer | 2018-02-13 17:55:43 -0800 | |
commit | 8742df3a115e0714fb25fb7ae199de93be205ab0 (patch) | |
tree | ef951ed4ca9ac8b6aeb0492ab1b6e27ba01c82fe /tensorflow/python/debug | |
parent | d0f4faeb2de8843d6996cba96c70af29feba3876 (diff) |
TensorBoard debugger plugin: SIGINT handler for easier termination of debugged runtime
from TensorBoardDebugWrapperSession and TensorBoardDebugHook.
PiperOrigin-RevId: 185617989
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/wrappers/grpc_wrapper.py | 27 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/hooks.py | 1 |
2 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 74d7c2b9e2..fb9494f576 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import signal +import sys import traceback # Google-internal import(s). @@ -137,6 +139,29 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): if not address.startswith(common.GRPC_URL_PREFIX) else address) +def _signal_handler(unused_signal, unused_frame): + try: + input_func = raw_input + except NameError: + # Python 3 does not have raw_input. + input_func = input + + while True: + response = input_func("\nSIGINT received. Quit program? (Y/n): ").strip() + if response in ("", "Y", "y"): + sys.exit(0) + elif response in ("N", "n"): + break + + +def register_signal_handler(): + try: + signal.signal(signal.SIGINT, _signal_handler) + except ValueError: + # This can happen if we are not in the MainThread. + pass + + class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin. @@ -185,6 +210,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): # sent to the debug servers. self._sent_graph_version = -1 + register_signal_handler() + def run(self, fetches, feed_dict=None, diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 0204254cca..6705cd31e2 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -345,6 +345,7 @@ class TensorBoardDebugHook(GrpcDebugHook): self._grpc_debug_server_addresses = grpc_debug_server_addresses self._send_traceback_and_source_code = send_traceback_and_source_code self._sent_graph_version = -1 + grpc_wrapper.register_signal_handler() def before_run(self, run_context): if self._send_traceback_and_source_code: |