diff options
author | Shanqing Cai <cais@google.com> | 2018-08-28 19:03:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 19:12:42 -0700 |
commit | 3d35a07179d4d38d0cabac4415c550f1cbce00c0 (patch) | |
tree | 34c06cc6339a273722f48689e7e58de88d18c0ba /tensorflow/python/debug | |
parent | 20d5683b826be03776978af3b8108fc3b5dc9cb8 (diff) |
tfdbg: Add adjustable limit to total bytes dumped to disk
RELNOTES: tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit.
PiperOrigin-RevId: 210648585
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_utils.py | 12 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/disk_usage_test.py | 109 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/framework.py | 25 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/hooks.py | 5 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/local_cli_wrapper.py | 5 |
6 files changed, 166 insertions, 7 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 55d2709845..849d165bfa 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -1100,6 +1100,23 @@ py_test( ], ) +py_test( + name = "disk_usage_test", + size = "small", + srcs = ["wrappers/disk_usage_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dumping_wrapper", + ":hooks", + "//tensorflow/python:client", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], +) + sh_test( name = "examples_test", size = "medium", diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py index f1e972940b..f2a43a6152 100644 --- a/tensorflow/python/debug/lib/debug_utils.py +++ b/tensorflow/python/debug/lib/debug_utils.py @@ -87,7 +87,8 @@ def watch_graph(run_options, op_type_regex_whitelist=None, tensor_dtype_regex_whitelist=None, tolerate_debug_op_creation_failures=False, - global_step=-1): + global_step=-1, + reset_disk_byte_usage=False): """Add debug watches to `RunOptions` for a TensorFlow graph. To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist` @@ -130,6 +131,8 @@ def watch_graph(run_options, throwing exceptions. global_step: (`int`) Optional global_step count for this debug tensor watch. + reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte + usage to zero (default: `False`). """ if isinstance(debug_ops, str): @@ -170,6 +173,7 @@ def watch_graph(run_options, tolerate_debug_op_creation_failures=( tolerate_debug_op_creation_failures), global_step=global_step) + run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage def watch_graph_with_blacklists(run_options, @@ -180,7 +184,8 @@ def watch_graph_with_blacklists(run_options, op_type_regex_blacklist=None, tensor_dtype_regex_blacklist=None, tolerate_debug_op_creation_failures=False, - global_step=-1): + global_step=-1, + reset_disk_byte_usage=False): """Add debug tensor watches, blacklisting nodes and op types. This is similar to `watch_graph()`, but the node names and op types are @@ -219,6 +224,8 @@ def watch_graph_with_blacklists(run_options, throwing exceptions. global_step: (`int`) Optional global_step count for this debug tensor watch. + reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte + usage to zero (default: `False`). """ if isinstance(debug_ops, str): @@ -259,3 +266,4 @@ def watch_graph_with_blacklists(run_options, tolerate_debug_op_creation_failures=( tolerate_debug_op_creation_failures), global_step=global_step) + run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage diff --git a/tensorflow/python/debug/wrappers/disk_usage_test.py b/tensorflow/python/debug/wrappers/disk_usage_test.py new file mode 100644 index 0000000000..0874525966 --- /dev/null +++ b/tensorflow/python/debug/wrappers/disk_usage_test.py @@ -0,0 +1,109 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Debugger Wrapper Session Consisting of a Local Curses-based CLI.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.python.client import session +from tensorflow.python.debug.wrappers import dumping_wrapper +from tensorflow.python.debug.wrappers import hooks +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.training import monitored_session + + +class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase): + + @classmethod + def setUpClass(cls): + # For efficient testing, set the disk usage bytes limit to a small + # number (10). + os.environ["TFDBG_DISK_BYTES_LIMIT"] = "10" + + def setUp(self): + self.session_root = tempfile.mkdtemp() + + self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v") + self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta") + self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta") + self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v") + self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v") + + self.sess = session.Session() + self.sess.run(self.v.initializer) + + def testWrapperSessionNotExceedingLimit(self): + def _watch_fn(fetches, feeds): + del fetches, feeds + return "DebugIdentity", r"(.*delta.*|.*inc_v.*)", r".*" + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, + watch_fn=_watch_fn, log_usage=False) + sess.run(self.inc_v) + + def testWrapperSessionExceedingLimit(self): + def _watch_fn(fetches, feeds): + del fetches, feeds + return "DebugIdentity", r".*delta.*", r".*" + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, + watch_fn=_watch_fn, log_usage=False) + # Due to the watch function, each run should dump only 1 tensor, + # which has a size of 4 bytes, which corresponds to the dumped 'delta:0' + # tensor of scalar shape and float32 dtype. + # 1st run should pass, after which the disk usage is at 4 bytes. + sess.run(self.inc_v) + # 2nd run should also pass, after which 8 bytes are used. + sess.run(self.inc_v) + # 3rd run should fail, because the total byte count (12) exceeds the + # limit (10) + with self.assertRaises(ValueError): + sess.run(self.inc_v) + + def testHookNotExceedingLimit(self): + def _watch_fn(fetches, feeds): + del fetches, feeds + return "DebugIdentity", r".*delta.*", r".*" + dumping_hook = hooks.DumpingDebugHook( + self.session_root, watch_fn=_watch_fn, log_usage=False) + mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook]) + mon_sess.run(self.inc_v) + + def testHookExceedingLimit(self): + def _watch_fn(fetches, feeds): + del fetches, feeds + return "DebugIdentity", r".*delta.*", r".*" + dumping_hook = hooks.DumpingDebugHook( + self.session_root, watch_fn=_watch_fn, log_usage=False) + mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook]) + # Like in `testWrapperSessionExceedingLimit`, the first two calls + # should be within the byte limit, but the third one should error + # out due to exceeding the limit. + mon_sess.run(self.inc_v) + mon_sess.run(self.inc_v) + with self.assertRaises(ValueError): + mon_sess.run(self.inc_v) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index b9524ce649..afda1fdc0d 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -447,13 +447,16 @@ class BaseDebugWrapperSession(session.SessionInterface): "callable_runner and callable_options are mutually exclusive, but " "are both specified in this call to BaseDebugWrapperSession.run().") - if not (callable_runner or callable_options): - self.increment_run_call_count() - elif callable_runner and (fetches or feed_dict): + if callable_runner and (fetches or feed_dict): raise ValueError( "callable_runner and fetches/feed_dict are mutually exclusive, " "but are used simultaneously.") + elif callable_options and (fetches or feed_dict): + raise ValueError( + "callable_options and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") + self.increment_run_call_count() empty_fetches = not nest.flatten(fetches) if empty_fetches: tf_logging.info( @@ -649,6 +652,18 @@ class BaseDebugWrapperSession(session.SessionInterface): def increment_run_call_count(self): self._run_call_count += 1 + def _is_disk_usage_reset_each_run(self): + """Indicates whether disk usage is reset after each Session.run. + + Subclasses that clean up the disk usage after every run should + override this protected method. + + Returns: + (`bool`) Whether the disk usage amount is reset to zero after + each Session.run. + """ + return False + def _decorate_run_options_for_debug( self, run_options, @@ -686,7 +701,9 @@ class BaseDebugWrapperSession(session.SessionInterface): node_name_regex_whitelist=node_name_regex_whitelist, op_type_regex_whitelist=op_type_regex_whitelist, tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist, - tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures) + tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures, + reset_disk_byte_usage=(self._run_call_count == 1 or + self._is_disk_usage_reset_each_run())) def _decorate_run_options_for_profile(self, run_options): """Modify a RunOptions object for profiling TensorFlow graph execution. diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 5e4604fda4..872b675506 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -188,6 +188,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook): pass def before_run(self, run_context): + reset_disk_byte_usage = False if not self._session_wrapper: self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession( run_context.session, @@ -195,6 +196,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook): watch_fn=self._watch_fn, thread_name_filter=self._thread_name_filter, log_usage=self._log_usage) + reset_disk_byte_usage = True self._session_wrapper.increment_run_call_count() @@ -212,7 +214,8 @@ class DumpingDebugHook(session_run_hook.SessionRunHook): op_type_regex_whitelist=watch_options.op_type_regex_whitelist, tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist, tolerate_debug_op_creation_failures=( - watch_options.tolerate_debug_op_creation_failures)) + watch_options.tolerate_debug_op_creation_failures), + reset_disk_byte_usage=reset_disk_byte_usage) run_args = session_run_hook.SessionRunArgs( None, feed_dict=None, options=run_options) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 668ffb57f1..a3ce4d388b 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -124,6 +124,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): self._ui_type = ui_type + def _is_disk_usage_reset_each_run(self): + # The dumped tensors are all cleaned up after every Session.run + # in a command-line wrapper. + return True + def _initialize_argparsers(self): self._argparsers = {} ap = argparse.ArgumentParser( |