diff options
author | 2017-09-26 18:00:16 -0700 | |
---|---|---|
committer | 2017-09-26 18:04:21 -0700 | |
commit | 35c44ab67d6e5d9b24f3f154c92e7aa3edfee957 (patch) | |
tree | 89c73cd4105df214c2469f270336234cf160f7a4 /tensorflow/python/debug | |
parent | 2733d24da31318208f85df20e5a54372c0a1af9f (diff) |
tfdbg: fix a bug re. string representation of SparseTensor feeds
Fixes: #12059
PiperOrigin-RevId: 170138936
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/cli_shared.py | 27 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/local_cli_wrapper.py | 13 | ||||
-rw-r--r-- | tensorflow/python/debug/wrappers/local_cli_wrapper_test.py | 4 |
4 files changed, 23 insertions, 23 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 05906a405a..ee53469cc7 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -330,7 +330,6 @@ py_library( ":stepper_cli", ":tensor_format", ":ui_factory", - "@six_archive//:six", ], ) @@ -941,6 +940,7 @@ py_test( ":cli_shared", ":debugger_cli_common", ":local_cli_wrapper", + ":ui_factory", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index 5d0e1d19d8..c3c9a332a7 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -214,18 +214,22 @@ def error(msg): RL("ERROR: " + msg, COLOR_RED)]) -def _get_fetch_name(fetch): - """Obtain the name or string representation of a fetch. +def get_graph_element_name(elem): + """Obtain the name or string representation of a graph element. + + If the graph element has the attribute "name", return name. Otherwise, return + a __str__ representation of the graph element. Certain graph elements, such as + `SparseTensor`s, do not have the attribute "name". Args: - fetch: The fetch in question. + elem: The graph element in question. Returns: If the attribute 'name' is available, return the name. Otherwise, return str(fetch). """ - return fetch.name if hasattr(fetch, "name") else str(fetch) + return elem.name if hasattr(elem, "name") else str(elem) def _get_fetch_names(fetches): @@ -250,7 +254,7 @@ def _get_fetch_names(fetches): else: # This ought to be a Tensor, an Operation or a Variable, for which the name # attribute should be available. (Bottom-out condition of the recursion.) - lines.append(_get_fetch_name(fetches)) + lines.append(get_graph_element_name(fetches)) return lines @@ -330,16 +334,13 @@ def get_run_start_intro(run_call_count, else: feed_dict_lines = [] for feed_key in feed_dict: - if isinstance(feed_key, six.string_types): - feed_key_name = feed_key - elif hasattr(feed_key, "name"): - feed_key_name = feed_key.name - else: - feed_key_name = str(feed_key) + feed_key_name = get_graph_element_name(feed_key) feed_dict_line = debugger_cli_common.RichLine(" ") feed_dict_line += debugger_cli_common.RichLine( feed_key_name, - debugger_cli_common.MenuItem(None, "pf %s" % feed_key_name)) + debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name)) + # Surround the name string with quotes, because feed_key_name may contain + # spaces in some cases, e.g., SparseTensors. feed_dict_lines.append(feed_dict_line) feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list( feed_dict_lines) @@ -445,7 +446,7 @@ def get_run_short_description(run_call_count, description = "run #%d: " % run_call_count if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): - description += "1 fetch (%s); " % _get_fetch_name(fetches) + description += "1 fetch (%s); " % get_graph_element_name(fetches) else: # Could be (nested) list, tuple, dict or namedtuple. num_fetches = len(_get_fetch_names(fetches)) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 7334a937f6..e06267ff5a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -23,8 +23,6 @@ import shutil import sys import tempfile -import six - # Google-internal import(s). from tensorflow.python.debug.cli import analyzer_cli from tensorflow.python.debug.cli import cli_shared @@ -465,12 +463,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): feed_key = None feed_value = None for key in self._feed_dict: - if isinstance(key, six.string_types): - if key == tensor_name: - feed_key = key - elif key.name == tensor_name: - feed_key = key.name - if feed_key is not None: + key_name = cli_shared.get_graph_element_name(key) + if key_name == tensor_name: + feed_key = key_name feed_value = self._feed_dict[key] break @@ -565,7 +560,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): list(self._tensor_filters.keys())) if self._feed_dict: # Register tab completion for feed_dict keys. - feed_keys = [(key if isinstance(key, six.string_types) else key.name) + feed_keys = [cli_shared.get_graph_element_name(key) for key in self._feed_dict.keys()] curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 8a2fe7283c..770a496aa9 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import debugger_cli_common +from tensorflow.python.debug.cli import ui_factory from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -100,6 +101,9 @@ class LocalCLIDebuggerWrapperSessionForTest( else: self.observers["run_end_cli_run_numbers"].append(self._run_call_count) + readline_cli = ui_factory.get_ui("readline") + self._register_this_run_info(readline_cli) + while True: command = self._command_sequence[self._command_pointer] self._command_pointer += 1 |