aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-03-10 15:58:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 16:11:01 -0800
commit1a833fb731d0daca9a10d55dc9d9f3c4dcaeaa1e (patch)
tree109d70ea5b24ad54b3f40a794125188523e4f010
parent38c0d21b50a357691ccbf190183d48c585653411 (diff)
tfdbg: add tensor_dtype_regex_whitelist & WatchOptions
* tensor_dtype_regex_whitelist allows filtering by tensor data type during watch_graph() calls. * tensor_dtype_regex_blacklist allows filtering (out) by tensor data type during watch_graph_with_blacklists() calls. * Allow watch_fns to return a structured data type (class) WatchOptions for usability and extensibility. The legacy tuple return type (debug_ops, node_name_regex_whitelist, op_type_regex_whitelist) continues to be supported. Change: 149808584
-rw-r--r--tensorflow/python/debug/__init__.py2
-rw-r--r--tensorflow/python/debug/lib/debug_utils.py48
-rw-r--r--tensorflow/python/debug/lib/debug_utils_test.py58
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py81
-rw-r--r--tensorflow/python/debug/wrappers/framework.py119
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py14
6 files changed, 270 insertions, 52 deletions
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index 2b2f9db862..d4a84f62cd 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -27,6 +27,7 @@ See the @{$python/tfdbg} guide.
@@DumpingDebugWrapperSession
@@LocalCLIDebugHook
@@LocalCLIDebugWrapperSession
+@@WatchOptions
"""
from __future__ import absolute_import
@@ -44,6 +45,7 @@ from tensorflow.python.debug.lib.debug_utils import watch_graph
from tensorflow.python.debug.lib.debug_utils import watch_graph_with_blacklists
from tensorflow.python.debug.wrappers.dumping_wrapper import DumpingDebugWrapperSession
+from tensorflow.python.debug.wrappers.framework import WatchOptions
from tensorflow.python.debug.wrappers.hooks import DumpingDebugHook
from tensorflow.python.debug.wrappers.hooks import LocalCLIDebugHook
from tensorflow.python.debug.wrappers.local_cli_wrapper import LocalCLIDebugWrapperSession
diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py
index 32df3b43a5..7163936631 100644
--- a/tensorflow/python/debug/lib/debug_utils.py
+++ b/tensorflow/python/debug/lib/debug_utils.py
@@ -77,6 +77,7 @@ def watch_graph(run_options,
debug_urls=None,
node_name_regex_whitelist=None,
op_type_regex_whitelist=None,
+ tensor_dtype_regex_whitelist=None,
tolerate_debug_op_creation_failures=False,
global_step=-1):
"""Add debug watches to `RunOptions` for a TensorFlow graph.
@@ -104,6 +105,10 @@ def watch_graph(run_options,
are set, the two filtering operations will occur in a logical `AND`
relation. In other words, a node will be included if and only if it
hits both whitelists.
+ tensor_dtype_regex_whitelist: Regular-experssion whitelist for Tensor
+ data type, e.g., `"^int.*"`.
+ This whitelist operates in logical `AND` relations to the two whitelists
+ above.
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
failures (e.g., due to dtype incompatibility) are to be tolerated by not
throwing exceptions.
@@ -114,15 +119,12 @@ def watch_graph(run_options,
if isinstance(debug_ops, str):
debug_ops = [debug_ops]
- if node_name_regex_whitelist:
- node_name_pattern = re.compile(node_name_regex_whitelist)
- else:
- node_name_pattern = None
-
- if op_type_regex_whitelist:
- op_type_pattern = re.compile(op_type_regex_whitelist)
- else:
- op_type_pattern = None
+ node_name_pattern = (re.compile(node_name_regex_whitelist)
+ if node_name_regex_whitelist else None)
+ op_type_pattern = (re.compile(op_type_regex_whitelist)
+ if op_type_regex_whitelist else None)
+ tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist)
+ if tensor_dtype_regex_whitelist else None)
ops = graph.get_operations()
for op in ops:
@@ -139,6 +141,10 @@ def watch_graph(run_options,
continue
for slot in xrange(len(op.outputs)):
+ if (tensor_dtype_pattern and
+ not tensor_dtype_pattern.match(op.outputs[slot].dtype.name)):
+ continue
+
add_debug_tensor_watch(
run_options,
node_name,
@@ -156,6 +162,7 @@ def watch_graph_with_blacklists(run_options,
debug_urls=None,
node_name_regex_blacklist=None,
op_type_regex_blacklist=None,
+ tensor_dtype_regex_blacklist=None,
tolerate_debug_op_creation_failures=False,
global_step=-1):
"""Add debug tensor watches, blacklisting nodes and op types.
@@ -182,6 +189,10 @@ def watch_graph_with_blacklists(run_options,
relation. In other words, a node will be excluded if it hits either of
the two blacklists; a node will be included if and only if it hits
neither of the blacklists.
+ tensor_dtype_regex_blacklist: Regular-experssion blacklist for Tensor
+ data type, e.g., `"^int.*"`.
+ This blacklist operates in logical `OR` relations to the two whitelists
+ above.
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
failures (e.g., due to dtype incompatibility) are to be tolerated by not
throwing exceptions.
@@ -192,15 +203,12 @@ def watch_graph_with_blacklists(run_options,
if isinstance(debug_ops, str):
debug_ops = [debug_ops]
- if node_name_regex_blacklist:
- node_name_pattern = re.compile(node_name_regex_blacklist)
- else:
- node_name_pattern = None
-
- if op_type_regex_blacklist:
- op_type_pattern = re.compile(op_type_regex_blacklist)
- else:
- op_type_pattern = None
+ node_name_pattern = (re.compile(node_name_regex_blacklist) if
+ node_name_regex_blacklist else None)
+ op_type_pattern = (re.compile(op_type_regex_blacklist) if
+ op_type_regex_blacklist else None)
+ tensor_dtype_pattern = (re.compile(tensor_dtype_regex_blacklist) if
+ tensor_dtype_regex_blacklist else None)
ops = graph.get_operations()
for op in ops:
@@ -217,6 +225,10 @@ def watch_graph_with_blacklists(run_options,
continue
for slot in xrange(len(op.outputs)):
+ if (tensor_dtype_pattern and
+ tensor_dtype_pattern.match(op.outputs[slot].dtype.name)):
+ continue
+
add_debug_tensor_watch(
run_options,
node_name,
diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py
index 49b8711f10..d4978fa235 100644
--- a/tensorflow/python/debug/lib/debug_utils_test.py
+++ b/tensorflow/python/debug/lib/debug_utils_test.py
@@ -253,6 +253,31 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(["p1"], node_names)
+ def testWatchGraph_tensorDTypeWhitelist(self):
+ debug_utils.watch_graph(
+ self._run_options,
+ self._graph,
+ debug_urls="file:///tmp/tfdbg_1",
+ tensor_dtype_regex_whitelist=".*_ref")
+
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
+ self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
+
+ def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
+ debug_utils.watch_graph(
+ self._run_options,
+ self._graph,
+ debug_urls="file:///tmp/tfdbg_1",
+ node_name_regex_whitelist="^a.*",
+ tensor_dtype_regex_whitelist=".*_ref")
+
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
+ self.assertItemsEqual(["a1", "a1/Assign"], node_names)
+
def testWatchGraph_nodeNameBlacklist(self):
debug_utils.watch_graph_with_blacklists(
self._run_options,
@@ -292,6 +317,39 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(["s"], node_names)
+ def testWatchGraph_tensorDTypeBlacklists(self):
+ debug_utils.watch_graph_with_blacklists(
+ self._run_options,
+ self._graph,
+ debug_urls="file:///tmp/tfdbg_1",
+ tensor_dtype_regex_blacklist=".*_ref")
+
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
+ self.assertNotIn("a1", node_names)
+ self.assertNotIn("a1/Assign", node_names)
+ self.assertNotIn("b", node_names)
+ self.assertNotIn("b/Assign", node_names)
+ self.assertIn("s", node_names)
+
+ def testWatchGraph_nodeNameAndTensorDTypeBlacklists(self):
+ debug_utils.watch_graph_with_blacklists(
+ self._run_options,
+ self._graph,
+ debug_urls="file:///tmp/tfdbg_1",
+ node_name_regex_blacklist="^s$",
+ tensor_dtype_regex_blacklist=".*_ref")
+
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
+ self.assertNotIn("a1", node_names)
+ self.assertNotIn("a1/Assign", node_names)
+ self.assertNotIn("b", node_names)
+ self.assertNotIn("b/Assign", node_names)
+ self.assertNotIn("s", node_names)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index e4096b623c..54bffd689a 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import stepper
from tensorflow.python.debug.wrappers import dumping_wrapper
+from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -144,7 +145,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
watch_fn=bad_watch_fn,
log_usage=False)
- def testDumpingWithWatchFnOnFetchesWorks(self):
+ def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
"""Use a watch_fn that returns different whitelists for different runs."""
def watch_fn(fetches, feeds):
@@ -186,8 +187,8 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(repr(self.dec_v), dump.run_fetches_info)
self.assertEqual(repr(None), dump.run_feed_keys_info)
- def testDumpingWithWatchFnWithNonDefaultDebugOpsWorks(self):
- """Use a watch_fn tha specifies non-default debug ops."""
+ def testDumpingWithLegacyWatchFnWithNonDefaultDebugOpsWorks(self):
+ """Use a watch_fn that specifies non-default debug ops."""
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -209,6 +210,37 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(12,
len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))
+ def testDumpingWithWatchFnWithNonDefaultDebugOpsWorks(self):
+ """Use a watch_fn that specifies non-default debug ops."""
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity", "DebugNumericSummary"],
+ node_name_regex_whitelist=r"^v.*",
+ op_type_regex_whitelist=r".*",
+ tensor_dtype_regex_whitelist=".*_ref")
+
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess,
+ session_root=self.session_root,
+ watch_fn=watch_fn,
+ log_usage=False)
+
+ sess.run(self.inc_v)
+
+ dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
+ self.assertEqual(1, len(dump_dirs))
+ dump = debug_data.DebugDumpDir(dump_dirs[0])
+
+ self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
+ self.assertEqual(12,
+ len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))
+
+ dumped_nodes = [dump.node_name for dump in dump.dumped_tensor_data]
+ self.assertNotIn("inc_v", dumped_nodes)
+ self.assertNotIn("delta", dumped_nodes)
+
def testDumpingDebugHookWithoutWatchFnWorks(self):
dumping_hook = hooks.DumpingDebugHook(self.session_root, log_usage=False)
mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
@@ -231,6 +263,49 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
del fetches, feed_dict
watch_fn_state["run_counter"] += 1
if watch_fn_state["run_counter"] % 2 == 1:
+ # If odd-index run (1-based), watch every ref-type tensor.
+ return framework.WatchOptions(
+ debug_ops="DebugIdentity",
+ tensor_dtype_regex_whitelist=".*_ref")
+ else:
+ # If even-index run, watch nothing.
+ return framework.WatchOptions(
+ debug_ops="DebugIdentity",
+ node_name_regex_whitelist=r"^$",
+ op_type_regex_whitelist=r"^$")
+
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=counting_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ for _ in range(4):
+ mon_sess.run(self.inc_v)
+
+ dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
+ dump_dirs = sorted(
+ dump_dirs, key=lambda x: int(os.path.basename(x).split("_")[1]))
+ self.assertEqual(4, len(dump_dirs))
+
+ for i, dump_dir in enumerate(dump_dirs):
+ self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
+ dump = debug_data.DebugDumpDir(dump_dir)
+ if i % 2 == 0:
+ self.assertAllClose([10.0 + 1.0 * i],
+ dump.get_tensors("v", 0, "DebugIdentity"))
+ self.assertNotIn("delta",
+ [datum.node_name for datum in dump.dumped_tensor_data])
+ else:
+ self.assertEqual(0, dump.size)
+
+ self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
+ self.assertEqual(repr(None), dump.run_feed_keys_info)
+
+ def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
+ watch_fn_state = {"run_counter": 0}
+
+ def counting_watch_fn(fetches, feed_dict):
+ del fetches, feed_dict
+ watch_fn_state["run_counter"] += 1
+ if watch_fn_state["run_counter"] % 2 == 1:
# If odd-index run (1-based), watch everything.
return "DebugIdentity", r".*", r".*"
else:
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 693015e718..bf1711e26a 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -238,7 +238,9 @@ class OnRunStartResponse(object):
debug_urls,
debug_ops="DebugIdentity",
node_name_regex_whitelist=None,
- op_type_regex_whitelist=None):
+ op_type_regex_whitelist=None,
+ tensor_dtype_regex_whitelist=None,
+ tolerate_debug_op_creation_failures=False):
"""Constructor of `OnRunStartResponse`.
Args:
@@ -251,6 +253,10 @@ class OnRunStartResponse(object):
node_name_regex_whitelist: Regular-expression whitelist for node
name.
op_type_regex_whitelist: Regular-expression whitelist for op type.
+ tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
+ dtype.
+ tolerate_debug_op_creation_failures: Whether debug op creation failures
+ are to be tolerated.
"""
_check_type(action, str)
@@ -263,6 +269,9 @@ class OnRunStartResponse(object):
self.node_name_regex_whitelist = node_name_regex_whitelist
self.op_type_regex_whitelist = op_type_regex_whitelist
+ self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
+ self.tolerate_debug_op_creation_failures = (
+ tolerate_debug_op_creation_failures)
class OnRunEndRequest(object):
@@ -412,7 +421,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_start_resp.debug_urls,
debug_ops=run_start_resp.debug_ops,
node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
- op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist)
+ op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=(
+ run_start_resp.tensor_dtype_regex_whitelist),
+ tolerate_debug_op_creation_failures=(
+ run_start_resp.tolerate_debug_op_creation_failures))
# Invoke the run() method of the wrapped Session. Catch any TensorFlow
# runtime errors.
@@ -474,7 +487,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
debug_urls,
debug_ops="DebugIdentity",
node_name_regex_whitelist=None,
- op_type_regex_whitelist=None):
+ op_type_regex_whitelist=None,
+ tensor_dtype_regex_whitelist=None,
+ tolerate_debug_op_creation_failures=False):
"""Modify a RunOptions object for debug tensor watching.
Specifies request for outputting partition graphs. Adds
@@ -488,6 +503,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
node_name_regex_whitelist: Regular-expression whitelist for node
name.
op_type_regex_whitelist: Regular-expression whitelist for op type.
+ tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
+ dtype.
+ tolerate_debug_op_creation_failures: Whether debug op creation failures
+ are to be tolerated.
"""
run_options.output_partition_graphs = True
@@ -497,7 +516,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
debug_urls=debug_urls,
debug_ops=debug_ops,
node_name_regex_whitelist=node_name_regex_whitelist,
- op_type_regex_whitelist=op_type_regex_whitelist)
+ op_type_regex_whitelist=op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
+ tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
@abc.abstractmethod
def on_session_init(self, request):
@@ -582,6 +603,56 @@ class BaseDebugWrapperSession(session.SessionInterface):
"""
+class WatchOptions(object):
+ """Type for return values of watch_fn."""
+
+ def __init__(self,
+ debug_ops=None,
+ node_name_regex_whitelist=None,
+ op_type_regex_whitelist=None,
+ tensor_dtype_regex_whitelist=None,
+ tolerate_debug_op_creation_failures=False):
+ """Constructor of WatchOptions: Debug watch options.
+
+ Used as return values of `watch_fn`s.
+
+ Args:
+ debug_ops: (`str` or `list of str`) Debug ops to be used.
+ node_name_regex_whitelist: Regular-expression whitelist for node_name,
+ e.g., `"(weight_[0-9]+|bias_.*)"`
+ op_type_regex_whitelist: Regular-expression whitelist for the op type of
+ nodes, e.g., `"(Variable|Add)"`.
+ If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
+ are set, the two filtering operations will occur in a logical `AND`
+ relation. In other words, a node will be included if and only if it
+ hits both whitelists.
+ tensor_dtype_regex_whitelist: Regular-experssion whitelist for Tensor
+ data type, e.g., `"^int.*"`.
+ This whitelist operates in logical `AND` relations to the two whitelists
+ above.
+ tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
+ failures (e.g., due to dtype incompatibility) are to be tolerated by not
+ throwing exceptions.
+ """
+ if debug_ops:
+ self.debug_ops = debug_ops
+ else:
+ self.debug_ops = ["DebugIdentity"]
+ self.node_name_regex_whitelist = node_name_regex_whitelist
+ self.op_type_regex_whitelist = op_type_regex_whitelist
+ self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
+ self.tolerate_debug_op_creation_failures = (
+ tolerate_debug_op_creation_failures)
+
+ def __repr__(self):
+ return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
+ "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
+ "tolerate_debug_op_creation_failures=%r)" % (
+ self.debug_ops, self.node_name_regex_whitelist,
+ self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist,
+ self.tolerate_debug_op_creation_failures))
+
+
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
"""Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""
@@ -645,16 +716,18 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
def on_run_start(self, request):
"""See doc of BaseDebugWrapperSession.on_run_start."""
- (debug_urls, debug_ops, node_name_regex_whitelist,
- op_type_regex_whitelist) = self._prepare_run_watch_config(
- request.fetches, request.feed_dict)
+ debug_urls, watch_opts = self._prepare_run_watch_config(
+ request.fetches, request.feed_dict)
return OnRunStartResponse(
OnRunStartAction.DEBUG_RUN,
debug_urls,
- debug_ops=debug_ops,
- node_name_regex_whitelist=node_name_regex_whitelist,
- op_type_regex_whitelist=op_type_regex_whitelist)
+ debug_ops=watch_opts.debug_ops,
+ node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
+ op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
+ tolerate_debug_op_creation_failures=(
+ watch_opts.tolerate_debug_op_creation_failures))
def _prepare_run_watch_config(self, fetches, feed_dict):
"""Get the debug_urls, and node/op whitelists for the current run() call.
@@ -666,24 +739,20 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
Returns:
debug_urls: (str or list of str) Debug URLs for the current run() call.
Currently, the list consists of only one URL that is a file:// URL.
- debug_ops: (str or list of str) Debug op(s) to be used by the
- debugger.
- node_name_regex_whitelist: (str or regex) Regular-expression whitelist for
- node name. Same as the same-name argument to debug_utils.watch_graph.
- op_type_regex_whitelist: (str or regex) Regular-expression whitelist for
- op type. Same as the same-name argument to debug_utils.watch_graph.
+ watch_options: (WatchOptions) The return value of a watch_fn, containing
+ options including debug_ops, and whitelists.
"""
debug_urls = self._prepare_run_debug_urls(fetches, feed_dict)
- debug_ops = "DebugIdentity"
- node_name_regex_whitelist = None
- op_type_regex_whitelist = None
- if self._watch_fn is not None:
- debug_ops, node_name_regex_whitelist, op_type_regex_whitelist = (
- self._watch_fn(fetches, feed_dict))
-
- return (debug_urls, debug_ops, node_name_regex_whitelist,
- op_type_regex_whitelist)
+ if self._watch_fn is None:
+ watch_options = WatchOptions()
+ else:
+ watch_options = self._watch_fn(fetches, feed_dict)
+ if isinstance(watch_options, tuple):
+ # For legacy return type (tuples).
+ watch_options = WatchOptions(*watch_options)
+
+ return debug_urls, watch_options
def on_run_end(self, request):
"""See doc of BaseDebugWrapperSession.on_run_end."""
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 15bde70ee7..bdb30ffc5d 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -176,17 +176,19 @@ class DumpingDebugHook(session_run_hook.SessionRunHook,
self._run_call_count += 1
- (debug_urls, debug_ops, node_name_regex_whitelist,
- op_type_regex_whitelist) = self._prepare_run_watch_config(
- run_context.original_args.fetches, run_context.original_args.feed_dict)
+ debug_urls, watch_options = self._prepare_run_watch_config(
+ run_context.original_args.fetches, run_context.original_args.feed_dict)
run_options = config_pb2.RunOptions()
debug_utils.watch_graph(
run_options,
run_context.session.graph,
debug_urls=debug_urls,
- debug_ops=debug_ops,
- node_name_regex_whitelist=node_name_regex_whitelist,
- op_type_regex_whitelist=op_type_regex_whitelist)
+ debug_ops=watch_options.debug_ops,
+ node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
+ op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
+ tolerate_debug_op_creation_failures=(
+ watch_options.tolerate_debug_op_creation_failures))
run_args = session_run_hook.SessionRunArgs(
None, feed_dict=None, options=run_options)