aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-03-13 17:11:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 17:25:43 -0700
commitbac72dc5a59b844d908d3badf77d62a8a93e2c3a (patch)
tree1009aeae805b41de05c0a5d135c718c34b99c81d /tensorflow/python/debug
parent8784e216a382a7831ea486e8350cdde812b7f888 (diff)
tfdbg: split session_debug_grpc_test
* so that the test sizes are medium for both the existing session_debug_grpc_test and the new grpc_large_data_test Also in this CL * Consolidate the functions for creating no-grappler-rewrite ConfigProtos in one place: in session_debug_testlib.py PiperOrigin-RevId: 188955135
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/BUILD26
-rw-r--r--tensorflow/python/debug/lib/grpc_large_data_test.py210
-rw-r--r--tensorflow/python/debug/lib/session_debug_file_test.py11
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py219
4 files changed, 264 insertions, 202 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 253588fc3b..512d292ee2 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -957,7 +957,7 @@ cuda_py_test(
cuda_py_test(
name = "session_debug_grpc_test",
- size = "large",
+ size = "medium",
srcs = ["lib/session_debug_grpc_test.py"],
additional_deps = [
":debug_data",
@@ -967,7 +967,6 @@ cuda_py_test(
":grpc_wrapper",
":hooks",
":session_debug_testlib",
- "//third_party/py/numpy",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -983,6 +982,29 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "grpc_large_data_test",
+ size = "medium",
+ srcs = ["lib/grpc_large_data_test.py"],
+ additional_deps = [
+ ":dumping_wrapper",
+ ":grpc_debug_test_server",
+ ":grpc_wrapper",
+ ":session_debug_testlib",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
+ "oss_serial",
+ ],
+)
+
# TODO(cais): Run the test in OSS, perhaps through a sh_test.
cuda_py_test(
name = "dist_session_debug_grpc_test",
diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py
new file mode 100644
index 0000000000..5bc477a9ba
--- /dev/null
+++ b/tensorflow/python/debug/lib/grpc_large_data_test.py
@@ -0,0 +1,210 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sending large-size data through tfdbg grpc channels.
+
+"Large-size data" includes large GraphDef protos and large Tensor protos.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.debug.lib import grpc_debug_test_server
+from tensorflow.python.debug.lib import session_debug_testlib
+from tensorflow.python.debug.wrappers import framework
+from tensorflow.python.debug.wrappers import grpc_wrapper
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
+ cls.debug_server
+ ) = grpc_debug_test_server.start_server_on_separate_thread(
+ dump_to_filesystem=False)
+ tf_logging.info("debug server url: %s", cls.debug_server_url)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.debug_server.stop_server().wait()
+ cls.debug_server_thread.join()
+
+ def tearDown(self):
+ ops.reset_default_graph()
+ self.debug_server.clear_data()
+
+ def testSendingLargeGraphDefsWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ u = variables.Variable(42.0, name="original_u")
+ for _ in xrange(50 * 1000):
+ u = array_ops.identity(u)
+ sess.run(variables.global_variables_initializer())
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"original_u")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ self.assertAllClose(42.0, sess.run(u))
+
+ self.assertAllClose(
+ [42.0],
+ self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
+ self.assertEqual(2 if test.is_gpu_available() else 1,
+ len(self.debug_server.partition_graph_defs))
+ max_graph_def_size = max([
+ len(graph_def.SerializeToString())
+ for graph_def in self.debug_server.partition_graph_defs])
+ self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
+
+ def testSendingLargeFloatTensorWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ u_init_val_array = list(xrange(1200 * 1024))
+ # Size: 4 * 1200 * 1024 = 4800k > 4M
+
+ u_init = constant_op.constant(
+ u_init_val_array, dtype=dtypes.float32, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds # Unused by this watch_fn.
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val_array,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ u_init_val = [
+ b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
+ u_init = constant_op.constant(
+ u_init_val, dtype=dtypes.string, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingLargeStringTensorWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ strs_total_size_threshold = 5000 * 1024
+ cum_size = 0
+ u_init_val_array = []
+ while cum_size < strs_total_size_threshold:
+ strlen = np.random.randint(200)
+ u_init_val_array.append(b"A" * strlen)
+ cum_size += strlen
+
+ u_init = constant_op.constant(
+ u_init_val_array, dtype=dtypes.string, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val_array,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingEmptyFloatTensorWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ u_init = constant_op.constant(
+ [], dtype=dtypes.float32, shape=[0], name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ u_init_value = self.debug_server.debug_tensor_values[
+ "u_init:0:DebugIdentity"][0]
+ self.assertEqual(np.float32, u_init_value.dtype)
+ self.assertEqual(0, len(u_init_value))
+
+ def testSendingEmptyStringTensorWorks(self):
+ with self.test_session(
+ use_gpu=True,
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
+ u_init = constant_op.constant(
+ [], dtype=dtypes.string, shape=[0], name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ u_init_value = self.debug_server.debug_tensor_values[
+ "u_init:0:DebugIdentity"][0]
+ self.assertEqual(np.object, u_init_value.dtype)
+ self.assertEqual(0, len(u_init_value))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index 1a6bedbbcb..ba0f15b4e2 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -22,7 +22,6 @@ import shutil
import tempfile
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
@@ -36,13 +35,6 @@ from tensorflow.python.platform import googletest
class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
- def _no_rewrite_session_config(self):
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True,
- arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- return config_pb2.ConfigProto(graph_options=graph_options)
-
def _debug_urls(self, run_number=None):
return ["file://%s" % self._debug_dump_dir(run_number=run_number)]
@@ -55,7 +47,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
def testAllowsDifferentWatchesOnDifferentRuns(self):
"""Test watching different tensors on different runs of the same graph."""
- with session.Session(config=self._no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
v_init_val = [[2.0], [-1.0]]
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index b623ee31c5..ff49b69547 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -24,11 +24,9 @@ from __future__ import print_function
import os
import shutil
-import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
@@ -38,28 +36,15 @@ from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import grpc_wrapper
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
from tensorflow.python.training import monitored_session
-def no_rewrite_session_config():
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True,
- arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- return config_pb2.ConfigProto(graph_options=graph_options)
-
-
class GrpcDebugServerTest(test_util.TensorFlowTestCase):
def testRepeatedRunServerRaisesException(self):
@@ -142,19 +127,22 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
return os.path.join(self._dump_root, "run_%d" % run_number)
def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
with self.assertRaisesRegexp(
TypeError, "Expected type str or list in grpc_debug_server_addresses"):
grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)
def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
with self.assertRaisesRegexp(
TypeError, "Expected type str in list grpc_debug_server_addresses"):
grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
with self.assertRaises(TypeError):
grpc_wrapper.GrpcDebugWrapperSession(
sess, "localhost:%d" % self._server_port, watch_fn="foo")
@@ -164,7 +152,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
sess.run(u.initializer)
sess.run(v.initializer)
@@ -190,7 +179,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
sess.run(u.initializer)
sess.run(v.initializer)
@@ -223,7 +213,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
sess.run(u.initializer)
sess.run(v.initializer)
@@ -254,7 +245,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
sess.run(u.initializer)
sess.run(v.initializer)
@@ -298,7 +290,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
- sess = session.Session(config=no_rewrite_session_config())
+ sess = session.Session(
+ config=session_debug_testlib.no_rewrite_session_config())
sess.run(variables.global_variables_initializer())
grpc_debug_hook = hooks.TensorBoardDebugHook(
@@ -324,168 +317,6 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
hooks.GrpcDebugHook(["foo:42424"])
-class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
-
- @classmethod
- def setUpClass(cls):
- (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
- cls.debug_server
- ) = grpc_debug_test_server.start_server_on_separate_thread(
- dump_to_filesystem=False)
- tf_logging.info("debug server url: %s", cls.debug_server_url)
-
- @classmethod
- def tearDownClass(cls):
- cls.debug_server.stop_server().wait()
- cls.debug_server_thread.join()
-
- def tearDown(self):
- ops.reset_default_graph()
- self.debug_server.clear_data()
-
- def testSendingLargeGraphDefsWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- u = variables.Variable(42.0, name="original_u")
- for _ in xrange(50 * 1000):
- u = array_ops.identity(u)
- sess.run(variables.global_variables_initializer())
-
- def watch_fn(fetches, feeds):
- del fetches, feeds
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"original_u")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- self.assertAllClose(42.0, sess.run(u))
-
- self.assertAllClose(
- [42.0],
- self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
- self.assertEqual(2 if test.is_gpu_available() else 1,
- len(self.debug_server.partition_graph_defs))
- max_graph_def_size = max([
- len(graph_def.SerializeToString())
- for graph_def in self.debug_server.partition_graph_defs])
- self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
-
- def testSendingLargeFloatTensorWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- u_init_val_array = list(xrange(1200 * 1024))
- # Size: 4 * 1200 * 1024 = 4800k > 4M
-
- u_init = constant_op.constant(
- u_init_val_array, dtype=dtypes.float32, name="u_init")
- u = variables.Variable(u_init, name="u")
-
- def watch_fn(fetches, feeds):
- del fetches, feeds # Unused by this watch_fn.
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"u_init")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- sess.run(u.initializer)
-
- self.assertAllEqual(
- u_init_val_array,
- self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
- def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- u_init_val = [
- b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
- u_init = constant_op.constant(
- u_init_val, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
-
- def watch_fn(fetches, feeds):
- del fetches, feeds
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"u_init")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- sess.run(u.initializer)
-
- self.assertAllEqual(
- u_init_val,
- self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
- def testSendingLargeStringTensorWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- strs_total_size_threshold = 5000 * 1024
- cum_size = 0
- u_init_val_array = []
- while cum_size < strs_total_size_threshold:
- strlen = np.random.randint(200)
- u_init_val_array.append(b"A" * strlen)
- cum_size += strlen
-
- u_init = constant_op.constant(
- u_init_val_array, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
-
- def watch_fn(fetches, feeds):
- del fetches, feeds
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"u_init")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- sess.run(u.initializer)
-
- self.assertAllEqual(
- u_init_val_array,
- self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
-
- def testSendingEmptyFloatTensorWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- u_init = constant_op.constant(
- [], dtype=dtypes.float32, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
-
- def watch_fn(fetches, feeds):
- del fetches, feeds
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"u_init")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- sess.run(u.initializer)
-
- u_init_value = self.debug_server.debug_tensor_values[
- "u_init:0:DebugIdentity"][0]
- self.assertEqual(np.float32, u_init_value.dtype)
- self.assertEqual(0, len(u_init_value))
-
- def testSendingEmptyStringTensorWorks(self):
- with self.test_session(
- use_gpu=True, config=no_rewrite_session_config()) as sess:
- u_init = constant_op.constant(
- [], dtype=dtypes.string, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
-
- def watch_fn(fetches, feeds):
- del fetches, feeds
- return framework.WatchOptions(
- debug_ops=["DebugIdentity"],
- node_name_regex_whitelist=r"u_init")
- sess = grpc_wrapper.GrpcDebugWrapperSession(
- sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
- sess.run(u.initializer)
-
- u_init_value = self.debug_server.debug_tensor_values[
- "u_init:0:DebugIdentity"][0]
- self.assertEqual(np.object, u_init_value.dtype)
- self.assertEqual(0, len(u_init_value))
-
-
class SessionDebugConcurrentTest(
session_debug_testlib.DebugConcurrentRunCallsTest):
@@ -548,7 +379,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
self._server_2.clear_data()
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v_1 = variables.Variable(50.0, name="v_1")
v_2 = variables.Variable(-50.0, name="v_1")
delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -617,7 +449,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
("toggled_2", 0, "DebugIdentity")])
self._servers_and_threads.append((server, server_thread))
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v_1 = variables.Variable(50.0, name="v_1")
v_2 = variables.Variable(-50.0, name="v_1")
# These two nodes have names that match those in the
@@ -656,7 +489,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
self.assertEqual(0, len(server.debug_tensor_values))
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v = variables.Variable(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -698,7 +532,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
self.assertEqual(0, len(self._server_2.debug_tensor_values))
def testToggleBreakpointsWorks(self):
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v_1 = variables.Variable(50.0, name="v_1")
v_2 = variables.Variable(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -755,7 +590,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
self.assertSetEqual(set(), self._server_1.breakpoints)
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v_1 = variables.Variable(50.0, name="v_1")
v_2 = variables.Variable(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
@@ -827,7 +663,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
self._server_1.query_source_file_line(__file__, 1)
def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
- with session.Session(config=no_rewrite_session_config()) as sess:
+ with session.Session(
+ config=session_debug_testlib.no_rewrite_session_config()) as sess:
v_1 = variables.Variable(50.0, name="v_1")
v_2 = variables.Variable(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")