diff options
author | Shanqing Cai <cais@google.com> | 2018-03-13 17:11:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-13 17:25:43 -0700 |
commit | bac72dc5a59b844d908d3badf77d62a8a93e2c3a (patch) | |
tree | 1009aeae805b41de05c0a5d135c718c34b99c81d /tensorflow/python/debug | |
parent | 8784e216a382a7831ea486e8350cdde812b7f888 (diff) |
tfdbg: split session_debug_grpc_test
* so that the test sizes are medium for both the existing session_debug_grpc_test
and the new grpc_large_data_test
Also in this CL
* Consolidate the functions for creating no-grappler-rewrite ConfigProtos
in one place: in session_debug_testlib.py
PiperOrigin-RevId: 188955135
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/grpc_large_data_test.py | 210 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_file_test.py | 11 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_grpc_test.py | 219 |
4 files changed, 264 insertions, 202 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 253588fc3b..512d292ee2 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -957,7 +957,7 @@ cuda_py_test( cuda_py_test( name = "session_debug_grpc_test", - size = "large", + size = "medium", srcs = ["lib/session_debug_grpc_test.py"], additional_deps = [ ":debug_data", @@ -967,7 +967,6 @@ cuda_py_test( ":grpc_wrapper", ":hooks", ":session_debug_testlib", - "//third_party/py/numpy", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -983,6 +982,29 @@ cuda_py_test( ], ) +cuda_py_test( + name = "grpc_large_data_test", + size = "medium", + srcs = ["lib/grpc_large_data_test.py"], + additional_deps = [ + ":dumping_wrapper", + ":grpc_debug_test_server", + ":grpc_wrapper", + ":session_debug_testlib", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], + tags = [ + "no_oss", # Test flaky due to port collisions. + "no_windows", + "oss_serial", + ], +) + # TODO(cais): Run the test in OSS, perhaps through a sh_test. cuda_py_test( name = "dist_session_debug_grpc_test", diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py new file mode 100644 index 0000000000..5bc477a9ba --- /dev/null +++ b/tensorflow/python/debug/lib/grpc_large_data_test.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sending large-size data through tfdbg grpc channels. + +"Large-size data" includes large GraphDef protos and large Tensor protos. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.debug.lib import grpc_debug_test_server +from tensorflow.python.debug.lib import session_debug_testlib +from tensorflow.python.debug.wrappers import framework +from tensorflow.python.debug.wrappers import grpc_wrapper +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): + + @classmethod + def setUpClass(cls): + (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, + cls.debug_server + ) = grpc_debug_test_server.start_server_on_separate_thread( + dump_to_filesystem=False) + tf_logging.info("debug server url: %s", cls.debug_server_url) + + @classmethod + def tearDownClass(cls): + cls.debug_server.stop_server().wait() + cls.debug_server_thread.join() + + def tearDown(self): + ops.reset_default_graph() + self.debug_server.clear_data() + + def testSendingLargeGraphDefsWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u = variables.Variable(42.0, name="original_u") + for _ in xrange(50 * 1000): + u = array_ops.identity(u) + sess.run(variables.global_variables_initializer()) + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"original_u") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + self.assertAllClose(42.0, sess.run(u)) + + self.assertAllClose( + [42.0], + self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"]) + self.assertEqual(2 if test.is_gpu_available() else 1, + len(self.debug_server.partition_graph_defs)) + max_graph_def_size = max([ + len(graph_def.SerializeToString()) + for graph_def in self.debug_server.partition_graph_defs]) + self.assertGreater(max_graph_def_size, 4 * 1024 * 1024) + + def testSendingLargeFloatTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init_val_array = list(xrange(1200 * 1024)) + # Size: 4 * 1200 * 1024 = 4800k > 4M + + u_init = constant_op.constant( + u_init_val_array, dtype=dtypes.float32, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds # Unused by this watch_fn. + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val_array, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingStringTensorWithAlmostTooLargeStringsWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init_val = [ + b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""] + u_init = constant_op.constant( + u_init_val, dtype=dtypes.string, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingLargeStringTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + strs_total_size_threshold = 5000 * 1024 + cum_size = 0 + u_init_val_array = [] + while cum_size < strs_total_size_threshold: + strlen = np.random.randint(200) + u_init_val_array.append(b"A" * strlen) + cum_size += strlen + + u_init = constant_op.constant( + u_init_val_array, dtype=dtypes.string, name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + self.assertAllEqual( + u_init_val_array, + self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) + + def testSendingEmptyFloatTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init = constant_op.constant( + [], dtype=dtypes.float32, shape=[0], name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + u_init_value = self.debug_server.debug_tensor_values[ + "u_init:0:DebugIdentity"][0] + self.assertEqual(np.float32, u_init_value.dtype) + self.assertEqual(0, len(u_init_value)) + + def testSendingEmptyStringTensorWorks(self): + with self.test_session( + use_gpu=True, + config=session_debug_testlib.no_rewrite_session_config()) as sess: + u_init = constant_op.constant( + [], dtype=dtypes.string, shape=[0], name="u_init") + u = variables.Variable(u_init, name="u") + + def watch_fn(fetches, feeds): + del fetches, feeds + return framework.WatchOptions( + debug_ops=["DebugIdentity"], + node_name_regex_whitelist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( + sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) + sess.run(u.initializer) + + u_init_value = self.debug_server.debug_tensor_values[ + "u_init:0:DebugIdentity"][0] + self.assertEqual(np.object, u_init_value.dtype) + self.assertEqual(0, len(u_init_value)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py index 1a6bedbbcb..ba0f15b4e2 100644 --- a/tensorflow/python/debug/lib/session_debug_file_test.py +++ b/tensorflow/python/debug/lib/session_debug_file_test.py @@ -22,7 +22,6 @@ import shutil import tempfile from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_utils @@ -36,13 +35,6 @@ from tensorflow.python.platform import googletest class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): - def _no_rewrite_session_config(self): - rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) - graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) - return config_pb2.ConfigProto(graph_options=graph_options) - def _debug_urls(self, run_number=None): return ["file://%s" % self._debug_dump_dir(run_number=run_number)] @@ -55,7 +47,8 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase): def testAllowsDifferentWatchesOnDifferentRuns(self): """Test watching different tensors on different runs of the same graph.""" - with session.Session(config=self._no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: u_init_val = [[5.0, 3.0], [-1.0, 0.0]] v_init_val = [[2.0], [-1.0]] diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index b623ee31c5..ff49b69547 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -24,11 +24,9 @@ from __future__ import print_function import os import shutil -import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_utils @@ -38,28 +36,15 @@ from tensorflow.python.debug.wrappers import framework from tensorflow.python.debug.wrappers import grpc_wrapper from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging from tensorflow.python.training import monitored_session -def no_rewrite_session_config(): - rewriter_config = rewriter_config_pb2.RewriterConfig( - disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) - graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) - return config_pb2.ConfigProto(graph_options=graph_options) - - class GrpcDebugServerTest(test_util.TensorFlowTestCase): def testRepeatedRunServerRaisesException(self): @@ -142,19 +127,22 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): return os.path.join(self._dump_root, "run_%d" % run_number) def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaisesRegexp( TypeError, "Expected type str or list in grpc_debug_server_addresses"): grpc_wrapper.GrpcDebugWrapperSession(sess, 1337) def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaisesRegexp( TypeError, "Expected type str in list grpc_debug_server_addresses"): grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338]) def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self): - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) with self.assertRaises(TypeError): grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self._server_port, watch_fn="foo") @@ -164,7 +152,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -190,7 +179,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -223,7 +213,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -254,7 +245,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(u.initializer) sess.run(v.initializer) @@ -298,7 +290,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") - sess = session.Session(config=no_rewrite_session_config()) + sess = session.Session( + config=session_debug_testlib.no_rewrite_session_config()) sess.run(variables.global_variables_initializer()) grpc_debug_hook = hooks.TensorBoardDebugHook( @@ -324,168 +317,6 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): hooks.GrpcDebugHook(["foo:42424"]) -class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): - - @classmethod - def setUpClass(cls): - (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, - cls.debug_server - ) = grpc_debug_test_server.start_server_on_separate_thread( - dump_to_filesystem=False) - tf_logging.info("debug server url: %s", cls.debug_server_url) - - @classmethod - def tearDownClass(cls): - cls.debug_server.stop_server().wait() - cls.debug_server_thread.join() - - def tearDown(self): - ops.reset_default_graph() - self.debug_server.clear_data() - - def testSendingLargeGraphDefsWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u = variables.Variable(42.0, name="original_u") - for _ in xrange(50 * 1000): - u = array_ops.identity(u) - sess.run(variables.global_variables_initializer()) - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"original_u") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - self.assertAllClose(42.0, sess.run(u)) - - self.assertAllClose( - [42.0], - self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"]) - self.assertEqual(2 if test.is_gpu_available() else 1, - len(self.debug_server.partition_graph_defs)) - max_graph_def_size = max([ - len(graph_def.SerializeToString()) - for graph_def in self.debug_server.partition_graph_defs]) - self.assertGreater(max_graph_def_size, 4 * 1024 * 1024) - - def testSendingLargeFloatTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init_val_array = list(xrange(1200 * 1024)) - # Size: 4 * 1200 * 1024 = 4800k > 4M - - u_init = constant_op.constant( - u_init_val_array, dtype=dtypes.float32, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds # Unused by this watch_fn. - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val_array, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingStringTensorWithAlmostTooLargeStringsWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init_val = [ - b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""] - u_init = constant_op.constant( - u_init_val, dtype=dtypes.string, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingLargeStringTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - strs_total_size_threshold = 5000 * 1024 - cum_size = 0 - u_init_val_array = [] - while cum_size < strs_total_size_threshold: - strlen = np.random.randint(200) - u_init_val_array.append(b"A" * strlen) - cum_size += strlen - - u_init = constant_op.constant( - u_init_val_array, dtype=dtypes.string, name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - self.assertAllEqual( - u_init_val_array, - self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) - - def testSendingEmptyFloatTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init = constant_op.constant( - [], dtype=dtypes.float32, shape=[0], name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - u_init_value = self.debug_server.debug_tensor_values[ - "u_init:0:DebugIdentity"][0] - self.assertEqual(np.float32, u_init_value.dtype) - self.assertEqual(0, len(u_init_value)) - - def testSendingEmptyStringTensorWorks(self): - with self.test_session( - use_gpu=True, config=no_rewrite_session_config()) as sess: - u_init = constant_op.constant( - [], dtype=dtypes.string, shape=[0], name="u_init") - u = variables.Variable(u_init, name="u") - - def watch_fn(fetches, feeds): - del fetches, feeds - return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") - sess = grpc_wrapper.GrpcDebugWrapperSession( - sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) - sess.run(u.initializer) - - u_init_value = self.debug_server.debug_tensor_values[ - "u_init:0:DebugIdentity"][0] - self.assertEqual(np.object, u_init_value.dtype) - self.assertEqual(0, len(u_init_value)) - - class SessionDebugConcurrentTest( session_debug_testlib.DebugConcurrentRunCallsTest): @@ -548,7 +379,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self._server_2.clear_data() def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_1") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -617,7 +449,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): ("toggled_2", 0, "DebugIdentity")]) self._servers_and_threads.append((server, server_thread)) - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_1") # These two nodes have names that match those in the @@ -656,7 +489,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(server.debug_tensor_values)) def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v = variables.Variable(50.0, name="v") delta = constant_op.constant(5.0, name="delta") inc_v = state_ops.assign_add(v, delta, name="inc_v") @@ -698,7 +532,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(self._server_2.debug_tensor_values)) def testToggleBreakpointsWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -755,7 +590,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self.assertSetEqual(set(), self._server_1.breakpoints) def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") @@ -827,7 +663,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): self._server_1.query_source_file_line(__file__, 1) def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self): - with session.Session(config=no_rewrite_session_config()) as sess: + with session.Session( + config=session_debug_testlib.no_rewrite_session_config()) as sess: v_1 = variables.Variable(50.0, name="v_1") v_2 = variables.Variable(-50.0, name="v_2") delta_1 = constant_op.constant(5.0, name="delta_1") |