aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-03-26 12:16:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 12:20:17 -0700
commit04b1e736897505ccf5b483379289d02a274ea586 (patch)
tree9f1b3742d4b23506d426c627fd5ed82a58d171c8 /tensorflow/python/debug
parentaf0fe569f48f3d5e8405eab76e14abde3c4e3d36 (diff)
tfdbg CLI: Allow node exclusion with tensor filters
Fixes: #16619 See the referred GitHub issue for details, but users want to be able to skip certain nodes when searching for inf/nans, because some nodes generate inf/nans even in nominal conditions. This CL adds a new optional flag `--filter_exclude_node_names` (or `-fenn` for short), which allows users to do exactly that, by using a regex for node names. RELNOTES: tfdbg CLI: Allow exclusion of nodes by regular expressions during tensor filter-enabled Session runs: see the new flags `--filter_exclude_node_names` (or `-fenn` for short). PiperOrigin-RevId: 190504225
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli.py22
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py26
-rw-r--r--tensorflow/python/debug/lib/debug_data.py14
-rw-r--r--tensorflow/python/debug/lib/session_debug_testlib.py49
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py39
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py36
6 files changed, 180 insertions, 6 deletions
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index 156afdfd4c..9a47cd12b4 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -186,6 +186,15 @@ class DebugAnalyzer(object):
default="",
help="List only Tensors passing the filter of the specified name")
ap.add_argument(
+ "-fenn",
+ "--filter_exclude_node_names",
+ dest="filter_exclude_node_names",
+ type=str,
+ default="",
+ help="When applying the tensor filter, exclude node with names "
+ "matching the regular expression. Applicable only if --tensor_filter "
+ "or -f is used.")
+ ap.add_argument(
"-n",
"--node_name_filter",
dest="node_name_filter",
@@ -484,6 +493,10 @@ class DebugAnalyzer(object):
Returns:
Output text lines as a RichTextLines object.
+
+ Raises:
+ ValueError: If `--filter_exclude_node_names` is used without `-f` or
+ `--tensor_filter` being used.
"""
# TODO(cais): Add annotations of substrings for dumped tensor names, to
@@ -520,8 +533,15 @@ class DebugAnalyzer(object):
_add_main_menu(output, node_name=None, enable_list_tensors=False)
return output
- data_to_show = self._debug_dump.find(filter_callable)
+ data_to_show = self._debug_dump.find(
+ filter_callable,
+ exclude_node_names=parsed.filter_exclude_node_names)
else:
+ if parsed.filter_exclude_node_names:
+ raise ValueError(
+ "The flag --filter_exclude_node_names is valid only when "
+ "the flag -f or --tensor_filter is used.")
+
data_to_show = self._debug_dump.dumped_tensor_data
# TODO(cais): Implement filter by lambda on tensor value.
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 6b110fda9e..55231954d1 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -820,6 +820,32 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
op_type_regex="(Add|MatMul)")
check_main_menu(self, out, list_tensors_enabled=False)
+ def testListTensorWithFilterAndNodeNameExclusionWorks(self):
+ # First, create and register the filter.
+ def is_2x1_vector(datum, tensor):
+ del datum # Unused.
+ return list(tensor.shape) == [2, 1]
+ self._analyzer.add_tensor_filter("is_2x1_vector", is_2x1_vector)
+
+ # Use shorthand alias for the command prefix.
+ out = self._registry.dispatch_command(
+ "lt", ["-f", "is_2x1_vector", "--filter_exclude_node_names", ".*v.*"])
+
+ # If the --filter_exclude_node_names were not used, then the matching
+ # tensors would be:
+ # - simple_mul_add/v:0
+ # - simple_mul_add/v/read:0
+ # - simple_mul_add/matmul:0
+ # - simple_mul_add/add:0
+ #
+ # With the --filter_exclude_node_names option, only the last two should
+ # show up in the result.
+ assert_listed_tensors(
+ self,
+ out, ["simple_mul_add/matmul:0", "simple_mul_add/add:0"],
+ ["MatMul", "Add"], tensor_filter_name="is_2x1_vector")
+ check_main_menu(self, out, list_tensors_enabled=False)
+
def testListTensorsFilterNanOrInf(self):
"""Test register and invoke a tensor filter."""
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index 8d355aa27f..8a65ad087b 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -23,6 +23,7 @@ import glob
import json
import os
import platform
+import re
import numpy as np
import six
@@ -1411,7 +1412,11 @@ class DebugDumpDir(object):
return self._watch_key_to_datum[device_name].get(debug_watch_key, [])
- def find(self, predicate, first_n=0, device_name=None):
+ def find(self,
+ predicate,
+ first_n=0,
+ device_name=None,
+ exclude_node_names=None):
"""Find dumped tensor data by a certain predicate.
Args:
@@ -1430,17 +1435,24 @@ class DebugDumpDir(object):
time order) for which the predicate returns True. To return all the
`DebugTensotDatum` instances, let first_n be <= 0.
device_name: optional device name.
+ exclude_node_names: Optional regular expression to exclude nodes with
+ names matching the regular expression.
Returns:
A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
for which predicate returns True, sorted in ascending order of the
timestamp.
"""
+ if exclude_node_names:
+ exclude_node_names = re.compile(exclude_node_names)
matched_data = []
for device in (self._dump_tensor_data if device_name is None
else (self._dump_tensor_data[device_name],)):
for datum in self._dump_tensor_data[device]:
+ if exclude_node_names and exclude_node_names.match(datum.node_name):
+ continue
+
if predicate(datum, datum.get_tensor()):
matched_data.append(datum)
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index f4fac14019..070d9c4cd7 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -669,6 +669,55 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(1, len(first_bad_datum))
self.assertEqual(x_name, first_bad_datum[0].node_name)
+ def testFindInfOrNanWithOpNameExclusion(self):
+ with session.Session() as sess:
+ u_name = "testFindInfOrNanWithOpNameExclusion/u"
+ v_name = "testFindInfOrNanWithOpNameExclusion/v"
+ w_name = "testFindInfOrNanWithOpNameExclusion/w"
+ x_name = "testFindInfOrNanWithOpNameExclusion/x"
+ y_name = "testFindInfOrNanWithOpNameExclusion/y"
+ z_name = "testFindInfOrNanWithOpNameExclusion/z"
+
+ u_init = constant_op.constant([2.0, 4.0])
+ u = variables.Variable(u_init, name=u_name)
+ v_init = constant_op.constant([2.0, 1.0])
+ v = variables.Variable(v_init, name=v_name)
+
+ # Expected output: [0.0, 3.0]
+ w = math_ops.subtract(u, v, name=w_name)
+
+ # Expected output: [inf, 1.3333]
+ x = math_ops.div(u, w, name=x_name)
+
+ # Expected output: [nan, 4.0]
+ y = math_ops.multiply(w, x, name=y_name)
+
+ z = math_ops.multiply(y, y, name=z_name)
+
+ u.initializer.run()
+ v.initializer.run()
+
+ _, dump = self._debug_run_and_get_dump(
+ sess, z,
+ expected_partition_graph_count=self._expected_partition_graph_count)
+
+ # Find all "offending tensors".
+ bad_data = dump.find(debug_data.has_inf_or_nan,
+ exclude_node_names=".*/x$")
+
+ # Verify that the nodes with bad values are caught through running find
+ # on the debug dump.
+ self.assertEqual(2, len(bad_data))
+ # Assert that the node `x` should have been excluded.
+ self.assertEqual(y_name, bad_data[0].node_name)
+ self.assertEqual(z_name, bad_data[1].node_name)
+
+ first_bad_datum = dump.find(
+ debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$")
+
+ self.assertEqual(1, len(first_bad_datum))
+ self.assertEqual(y_name, first_bad_datum[0].node_name)
+
def _session_run_for_graph_structure_lookup(self):
with session.Session(config=no_rewrite_session_config()) as sess:
u_name = "testDumpGraphStructureLookup/u"
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 1465cb7295..c8625655e5 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -115,6 +115,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
# unavailable (i.e., is None), the run-start CLI will be launched to ask
# the user. This is the case, e.g., right before the first run starts.
self._active_tensor_filter = None
+ self._active_filter_exclude_node_names = None
self._active_tensor_filter_run_start_response = None
self._run_through_times = 1
self._skip_debug = False
@@ -149,6 +150,15 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
default="",
help="Run until a tensor in the graph passes the specified filter.")
ap.add_argument(
+ "-fenn",
+ "--filter_exclude_node_names",
+ dest="filter_exclude_node_names",
+ type=str,
+ default="",
+ help="When applying the tensor filter, exclude node with names "
+ "matching the regular expression. Applicable only if --tensor_filter "
+ "or -f is used.")
+ ap.add_argument(
"--node_name_filter",
dest="node_name_filter",
type=str,
@@ -324,9 +334,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
debug_dump.set_python_graph(self._sess.graph)
passed_filter = None
+ passed_filter_exclude_node_names = None
if self._active_tensor_filter:
if not debug_dump.find(
- self._tensor_filters[self._active_tensor_filter], first_n=1):
+ self._tensor_filters[self._active_tensor_filter], first_n=1,
+ exclude_node_names=self._active_filter_exclude_node_names):
# No dumped tensor passes the filter in this run. Clean up the dump
# directory and move on.
self._remove_dump_root()
@@ -334,10 +346,14 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
else:
# Some dumped tensor(s) from this run passed the filter.
passed_filter = self._active_tensor_filter
+ passed_filter_exclude_node_names = (
+ self._active_filter_exclude_node_names)
self._active_tensor_filter = None
+ self._active_filter_exclude_node_names = None
self._prep_debug_cli_for_run_end(
- debug_dump, request.tf_error, passed_filter)
+ debug_dump, request.tf_error, passed_filter,
+ passed_filter_exclude_node_names)
self._run_start_response = self._launch_cli()
@@ -358,7 +374,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
if os.path.isdir(self._dump_root):
shutil.rmtree(self._dump_root)
- def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
+ def _prep_debug_cli_for_run_end(self,
+ debug_dump,
+ tf_error,
+ passed_filter,
+ passed_filter_exclude_node_names):
"""Prepare (but not launch) CLI for run-end, with debug dump from the run.
Args:
@@ -368,6 +388,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
(if any).
passed_filter: (None or str) Name of the tensor filter that just passed
and caused the preparation of this run-end CLI (if any).
+ passed_filter_exclude_node_names: (None or str) Regular expression used
+ with the tensor filter to exclude ops with names matching the regular
+ expresssion.
"""
if tf_error:
@@ -383,6 +406,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
if passed_filter is not None:
# Some dumped tensor(s) from this run passed the filter.
self._init_command = "lt -f %s" % passed_filter
+ if passed_filter_exclude_node_names:
+ self._init_command += (" --filter_exclude_node_names %s" %
+ passed_filter_exclude_node_names)
self._title_color = "red_on_white"
self._run_cli = analyzer_cli.create_analyzer_ui(
@@ -496,6 +522,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
parsed.op_type_filter = parsed.op_type_filter or None
parsed.tensor_dtype_filter = parsed.tensor_dtype_filter or None
+ if parsed.filter_exclude_node_names and not parsed.till_filter_pass:
+ raise ValueError(
+ "The --filter_exclude_node_names (or -feon) flag is valid only if "
+ "the --till_filter_pass (or -f) flag is used.")
+
if parsed.profile:
raise debugger_cli_common.CommandLineExit(
exit_token=framework.OnRunStartResponse(
@@ -525,6 +556,8 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
if parsed.till_filter_pass in self._tensor_filters:
action = framework.OnRunStartAction.DEBUG_RUN
self._active_tensor_filter = parsed.till_filter_pass
+ self._active_filter_exclude_node_names = (
+ parsed.filter_exclude_node_names)
self._active_tensor_filter_run_start_response = run_start_response
else:
# Handle invalid filter name.
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 490812c96d..b06fa26a93 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -87,7 +87,11 @@ class LocalCLIDebuggerWrapperSessionForTest(
def _prep_cli_for_run_start(self):
pass
- def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
+ def _prep_debug_cli_for_run_end(self,
+ debug_dump,
+ tf_error,
+ passed_filter,
+ passed_filter_exclude_op_names):
self.observers["debug_dumps"].append(debug_dump)
self.observers["tf_errors"].append(tf_error)
@@ -451,6 +455,36 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])
+ def testRunTillFilterPassesWithExcludeOpNames(self):
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run", "-f", "greater_than_twelve",
+ "--filter_exclude_node_names", "inc_v.*"],
+ ["run"], ["run"]],
+ self.sess,
+ dump_root=self._tmp_dir)
+
+ def greater_than_twelve(datum, tensor):
+ del datum # Unused.
+ return tensor > 12.0
+
+ # Verify that adding the same tensor filter more than once is tolerated
+ # (i.e., as if it were added only once).
+ wrapped_sess.add_tensor_filter("greater_than_twelve", greater_than_twelve)
+
+ # run five times.
+ wrapped_sess.run(self.inc_v)
+ wrapped_sess.run(self.inc_v)
+ wrapped_sess.run(self.inc_v)
+ wrapped_sess.run(self.inc_v)
+
+ self.assertAllClose(14.0, self.sess.run(self.v))
+
+ self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
+
+ # Due to the --filter_exclude_op_names flag, the run-end CLI should show up
+ # not after run 3, but after run 4.
+ self.assertEqual([4], wrapped_sess.observers["run_end_cli_run_numbers"])
+
def testRunTillFilterPassesWorksInConjunctionWithOtherNodeNameFilter(self):
"""Test that --.*_filter flags work in conjunction with -f.