aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-08-28 19:03:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 19:12:42 -0700
commit3d35a07179d4d38d0cabac4415c550f1cbce00c0 (patch)
tree34c06cc6339a273722f48689e7e58de88d18c0ba /tensorflow/python/debug
parent20d5683b826be03776978af3b8108fc3b5dc9cb8 (diff)
tfdbg: Add adjustable limit to total bytes dumped to disk
RELNOTES: tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. PiperOrigin-RevId: 210648585
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/BUILD17
-rw-r--r--tensorflow/python/debug/lib/debug_utils.py12
-rw-r--r--tensorflow/python/debug/wrappers/disk_usage_test.py109
-rw-r--r--tensorflow/python/debug/wrappers/framework.py25
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py5
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py5
6 files changed, 166 insertions, 7 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 55d2709845..849d165bfa 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -1100,6 +1100,23 @@ py_test(
],
)
+py_test(
+ name = "disk_usage_test",
+ size = "small",
+ srcs = ["wrappers/disk_usage_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dumping_wrapper",
+ ":hooks",
+ "//tensorflow/python:client",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
sh_test(
name = "examples_test",
size = "medium",
diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py
index f1e972940b..f2a43a6152 100644
--- a/tensorflow/python/debug/lib/debug_utils.py
+++ b/tensorflow/python/debug/lib/debug_utils.py
@@ -87,7 +87,8 @@ def watch_graph(run_options,
op_type_regex_whitelist=None,
tensor_dtype_regex_whitelist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug watches to `RunOptions` for a TensorFlow graph.
To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
@@ -130,6 +131,8 @@ def watch_graph(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -170,6 +173,7 @@ def watch_graph(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
def watch_graph_with_blacklists(run_options,
@@ -180,7 +184,8 @@ def watch_graph_with_blacklists(run_options,
op_type_regex_blacklist=None,
tensor_dtype_regex_blacklist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug tensor watches, blacklisting nodes and op types.
This is similar to `watch_graph()`, but the node names and op types are
@@ -219,6 +224,8 @@ def watch_graph_with_blacklists(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -259,3 +266,4 @@ def watch_graph_with_blacklists(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
diff --git a/tensorflow/python/debug/wrappers/disk_usage_test.py b/tensorflow/python/debug/wrappers/disk_usage_test.py
new file mode 100644
index 0000000000..0874525966
--- /dev/null
+++ b/tensorflow/python/debug/wrappers/disk_usage_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the disk usage limit on tfdbg's dumped tensor data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.python.client import session
+from tensorflow.python.debug.wrappers import dumping_wrapper
+from tensorflow.python.debug.wrappers import hooks
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import monitored_session
+
+
+class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # For efficient testing, set the disk usage bytes limit to a small
+ # number (10).
+ os.environ["TFDBG_DISK_BYTES_LIMIT"] = "10"
+
+ def setUp(self):
+ self.session_root = tempfile.mkdtemp()
+
+ self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
+ self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
+ self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
+ self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")
+
+ self.sess = session.Session()
+ self.sess.run(self.v.initializer)
+
+ def testWrapperSessionNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r"(.*delta.*|.*inc_v.*)", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+ sess.run(self.inc_v)
+
+ def testWrapperSessionExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+ # Due to the watch function, each run should dump only 1 tensor,
+ # which has a size of 4 bytes, which corresponds to the dumped 'delta:0'
+ # tensor of scalar shape and float32 dtype.
+ # 1st run should pass, after which the disk usage is at 4 bytes.
+ sess.run(self.inc_v)
+ # 2nd run should also pass, after which 8 bytes are used.
+ sess.run(self.inc_v)
+ # 3rd run should fail, because the total byte count (12) exceeds the
+ # limit (10).
+ with self.assertRaises(ValueError):
+ sess.run(self.inc_v)
+
+ def testHookNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ mon_sess.run(self.inc_v)
+
+ def testHookExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ # Like in `testWrapperSessionExceedingLimit`, the first two calls
+ # should be within the byte limit, but the third one should error
+ # out due to exceeding the limit.
+ mon_sess.run(self.inc_v)
+ mon_sess.run(self.inc_v)
+ with self.assertRaises(ValueError):
+ mon_sess.run(self.inc_v)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index b9524ce649..afda1fdc0d 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -447,13 +447,16 @@ class BaseDebugWrapperSession(session.SessionInterface):
"callable_runner and callable_options are mutually exclusive, but "
"are both specified in this call to BaseDebugWrapperSession.run().")
- if not (callable_runner or callable_options):
- self.increment_run_call_count()
- elif callable_runner and (fetches or feed_dict):
+ if callable_runner and (fetches or feed_dict):
raise ValueError(
"callable_runner and fetches/feed_dict are mutually exclusive, "
"but are used simultaneously.")
+ elif callable_options and (fetches or feed_dict):
+ raise ValueError(
+ "callable_options and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
+ self.increment_run_call_count()
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
tf_logging.info(
@@ -649,6 +652,18 @@ class BaseDebugWrapperSession(session.SessionInterface):
def increment_run_call_count(self):
self._run_call_count += 1
+ def _is_disk_usage_reset_each_run(self):
+ """Indicates whether disk usage is reset after each Session.run.
+
+ Subclasses that clean up the disk usage after every run should
+ override this protected method.
+
+ Returns:
+ (`bool`) Whether the disk usage amount is reset to zero after
+ each Session.run.
+ """
+ return False
+
def _decorate_run_options_for_debug(
self,
run_options,
@@ -686,7 +701,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
node_name_regex_whitelist=node_name_regex_whitelist,
op_type_regex_whitelist=op_type_regex_whitelist,
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
- tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
+ tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
+ reset_disk_byte_usage=(self._run_call_count == 1 or
+ self._is_disk_usage_reset_each_run()))
def _decorate_run_options_for_profile(self, run_options):
"""Modify a RunOptions object for profiling TensorFlow graph execution.
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 5e4604fda4..872b675506 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -188,6 +188,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
pass
def before_run(self, run_context):
+ reset_disk_byte_usage = False
if not self._session_wrapper:
self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
run_context.session,
@@ -195,6 +196,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
watch_fn=self._watch_fn,
thread_name_filter=self._thread_name_filter,
log_usage=self._log_usage)
+ reset_disk_byte_usage = True
self._session_wrapper.increment_run_call_count()
@@ -212,7 +214,8 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
tolerate_debug_op_creation_failures=(
- watch_options.tolerate_debug_op_creation_failures))
+ watch_options.tolerate_debug_op_creation_failures),
+ reset_disk_byte_usage=reset_disk_byte_usage)
run_args = session_run_hook.SessionRunArgs(
None, feed_dict=None, options=run_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 668ffb57f1..a3ce4d388b 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -124,6 +124,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._ui_type = ui_type
+ def _is_disk_usage_reset_each_run(self):
+ # The dumped tensors are all cleaned up after every Session.run
+ # in a command-line wrapper.
+ return True
+
def _initialize_argparsers(self):
self._argparsers = {}
ap = argparse.ArgumentParser(