aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-05-24 14:02:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 14:05:21 -0700
commitd2090672fe8305289156460c43f7fcc1a5dd5422 (patch)
tree9fc8dc6b47beaac3d1abfeca10e92fc6748f94e8 /tensorflow/python/debug
parent51645f15b3854447c887abf0e92d0465d79ea92c (diff)
tfdbg: fix issue where total source file size exceeds gRPC message size limit
* Source file content is now sent one file at a time, making it less likely that individual messages will have sizes above the 4-MB gRPC message size limit. * In case the message for a single source file exceeds the limit, the client handles it gracefully by skipping the send and printing a warning message. Fixes: https://github.com/tensorflow/tensorboard/issues/1118 PiperOrigin-RevId: 197949416
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/BUILD1
-rw-r--r--tensorflow/python/debug/lib/grpc_debug_test_server.py13
-rw-r--r--tensorflow/python/debug/lib/source_remote.py23
-rw-r--r--tensorflow/python/debug/lib/source_remote_test.py46
4 files changed, 74 insertions, 9 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 16ae74a19f..09062abd74 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -572,6 +572,7 @@ py_test(
":source_utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py
index 9170046948..a7be20948d 100644
--- a/tensorflow/python/debug/lib/grpc_debug_test_server.py
+++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py
@@ -245,7 +245,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
self._origin_id_to_strings = []
self._graph_tracebacks = []
self._graph_versions = []
- self._source_files = None
+ self._source_files = []
def _initialize_toggle_watch_state(self, toggle_watches):
self._toggle_watches = toggle_watches
@@ -274,7 +274,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
self._origin_id_to_strings = []
self._graph_tracebacks = []
self._graph_versions = []
- self._source_files = None
+ self._source_files = []
def SendTracebacks(self, request, context):
self._call_types.append(request.call_type)
@@ -286,7 +286,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
return debug_service_pb2.EventReply()
def SendSourceFiles(self, request, context):
- self._source_files = request
+ self._source_files.append(request)
return debug_service_pb2.EventReply()
def query_op_traceback(self, op_name):
@@ -351,9 +351,10 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
if not self._source_files:
raise ValueError(
"This debug server has not received any source file contents yet.")
- for source_file_proto in self._source_files.source_files:
- if source_file_proto.file_path == file_path:
- return source_file_proto.lines[lineno - 1]
+ for source_files in self._source_files:
+ for source_file_proto in source_files.source_files:
+ if source_file_proto.file_path == file_path:
+ return source_file_proto.lines[lineno - 1]
raise ValueError(
"Source file at path %s has not been received by the debug server",
file_path)
diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py
index 4b6b2b995e..4afae41bc9 100644
--- a/tensorflow/python/debug/lib/source_remote.py
+++ b/tensorflow/python/debug/lib/source_remote.py
@@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
from tensorflow.python.profiler import tfprof_logger
@@ -95,6 +96,11 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
return non_tf_files
+def grpc_message_length_bytes():
+ """Maximum gRPC message length in bytes."""
+ return 4 * 1024 * 1024
+
+
def _send_call_tracebacks(destinations,
origin_stack,
is_eager_execution=False,
@@ -155,17 +161,28 @@ def _send_call_tracebacks(destinations,
source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
[call_traceback.origin_stack], call_traceback.origin_id_to_string))
- debugged_source_files = debug_pb2.DebuggedSourceFiles()
+ debugged_source_files = []
for file_path in source_file_paths:
+ source_files = debug_pb2.DebuggedSourceFiles()
_load_debugged_source_file(
- file_path, debugged_source_files.source_files.add())
+ file_path, source_files.source_files.add())
+ debugged_source_files.append(source_files)
for destination in destinations:
channel = grpc.insecure_channel(destination)
stub = debug_service_pb2_grpc.EventListenerStub(channel)
stub.SendTracebacks(call_traceback)
if send_source:
- stub.SendSourceFiles(debugged_source_files)
+ for path, source_files in zip(
+ source_file_paths, debugged_source_files):
+ if source_files.ByteSize() < grpc_message_length_bytes():
+ stub.SendSourceFiles(source_files)
+ else:
+ tf_logging.warn(
+ "The content of the source file at %s is not sent to "
+ "gRPC debug server %s, because the message size exceeds "
+ "gRPC message length limit (%d bytes)." % (
+ path, destination, grpc_message_length_bytes()))
def send_graph_tracebacks(destinations,
diff --git a/tensorflow/python/debug/lib/source_remote_test.py b/tensorflow/python/debug/lib/source_remote_test.py
index 27bafa45e1..29add425e9 100644
--- a/tensorflow/python/debug/lib/source_remote_test.py
+++ b/tensorflow/python/debug/lib/source_remote_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -155,6 +156,51 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
self.assertEqual(["dummy_run_key"], server.query_call_keys())
self.assertEqual([sess.graph.version], server.query_graph_versions())
+ def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
+ """In case source file size exceeds the grpc message length limit.
+
+ it ought not to have been sent to the server.
+ """
+ this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
+
+ # Patch the method to simulate a very small message length limit.
+ with test.mock.patch.object(
+ source_remote, "grpc_message_length_bytes", return_value=2):
+ with session.Session() as sess:
+ a = variables.Variable(21.0, name="two/a")
+ a_lineno = line_number_above()
+ b = variables.Variable(2.0, name="two/b")
+ b_lineno = line_number_above()
+ x = math_ops.add(a, b, name="two/x")
+ x_lineno = line_number_above()
+
+ send_traceback = traceback.extract_stack()
+ send_lineno = line_number_above()
+ source_remote.send_graph_tracebacks(
+ [self._server_address, self._server_address_2],
+ "dummy_run_key", send_traceback, sess.graph)
+
+ servers = [self._server, self._server_2]
+ for server in servers:
+ # Even though the source file content is not sent, the traceback
+ # should have been sent.
+ tb = server.query_op_traceback("two/a")
+ self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
+ tb = server.query_op_traceback("two/b")
+ self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
+ tb = server.query_op_traceback("two/x")
+ self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
+
+ self.assertIn(
+ (self._curr_file_path, send_lineno, this_func_name),
+ server.query_origin_stack()[-1])
+
+ tf_trace_file_path = (
+ self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
+ # Verify that the source content is not sent to the server.
+ with self.assertRaises(ValueError):
+ self._server.query_source_file_line(tf_trace_file_path, 0)
+
def testSendEagerTracebacksToSingleDebugServer(self):
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
send_traceback = traceback.extract_stack()