aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-09-26 18:00:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 18:04:21 -0700
commit35c44ab67d6e5d9b24f3f154c92e7aa3edfee957 (patch)
tree89c73cd4105df214c2469f270336234cf160f7a4 /tensorflow/python/debug
parent2733d24da31318208f85df20e5a54372c0a1af9f (diff)
tfdbg: fix a bug re. string representation of SparseTensor feeds
Fixes: #12059 PiperOrigin-RevId: 170138936
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/BUILD2
-rw-r--r--tensorflow/python/debug/cli/cli_shared.py27
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py13
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py4
4 files changed, 23 insertions, 23 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 05906a405a..ee53469cc7 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -330,7 +330,6 @@ py_library(
":stepper_cli",
":tensor_format",
":ui_factory",
- "@six_archive//:six",
],
)
@@ -941,6 +940,7 @@ py_test(
":cli_shared",
":debugger_cli_common",
":local_cli_wrapper",
+ ":ui_factory",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py
index 5d0e1d19d8..c3c9a332a7 100644
--- a/tensorflow/python/debug/cli/cli_shared.py
+++ b/tensorflow/python/debug/cli/cli_shared.py
@@ -214,18 +214,22 @@ def error(msg):
RL("ERROR: " + msg, COLOR_RED)])
-def _get_fetch_name(fetch):
- """Obtain the name or string representation of a fetch.
+def get_graph_element_name(elem):
+ """Obtain the name or string representation of a graph element.
+
+ If the graph element has the attribute "name", return name. Otherwise, return
+ a __str__ representation of the graph element. Certain graph elements, such as
+ `SparseTensor`s, do not have the attribute "name".
Args:
- fetch: The fetch in question.
+ elem: The graph element in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(fetch).
"""
- return fetch.name if hasattr(fetch, "name") else str(fetch)
+ return elem.name if hasattr(elem, "name") else str(elem)
def _get_fetch_names(fetches):
@@ -250,7 +254,7 @@ def _get_fetch_names(fetches):
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
- lines.append(_get_fetch_name(fetches))
+ lines.append(get_graph_element_name(fetches))
return lines
@@ -330,16 +334,13 @@ def get_run_start_intro(run_call_count,
else:
feed_dict_lines = []
for feed_key in feed_dict:
- if isinstance(feed_key, six.string_types):
- feed_key_name = feed_key
- elif hasattr(feed_key, "name"):
- feed_key_name = feed_key.name
- else:
- feed_key_name = str(feed_key)
+ feed_key_name = get_graph_element_name(feed_key)
feed_dict_line = debugger_cli_common.RichLine(" ")
feed_dict_line += debugger_cli_common.RichLine(
feed_key_name,
- debugger_cli_common.MenuItem(None, "pf %s" % feed_key_name))
+ debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
+ # Surround the name string with quotes, because feed_key_name may contain
+ # spaces in some cases, e.g., SparseTensors.
feed_dict_lines.append(feed_dict_line)
feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
feed_dict_lines)
@@ -445,7 +446,7 @@ def get_run_short_description(run_call_count,
description = "run #%d: " % run_call_count
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
- description += "1 fetch (%s); " % _get_fetch_name(fetches)
+ description += "1 fetch (%s); " % get_graph_element_name(fetches)
else:
# Could be (nested) list, tuple, dict or namedtuple.
num_fetches = len(_get_fetch_names(fetches))
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 7334a937f6..e06267ff5a 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -23,8 +23,6 @@ import shutil
import sys
import tempfile
-import six
-
# Google-internal import(s).
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.debug.cli import cli_shared
@@ -465,12 +463,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
feed_key = None
feed_value = None
for key in self._feed_dict:
- if isinstance(key, six.string_types):
- if key == tensor_name:
- feed_key = key
- elif key.name == tensor_name:
- feed_key = key.name
- if feed_key is not None:
+ key_name = cli_shared.get_graph_element_name(key)
+ if key_name == tensor_name:
+ feed_key = key_name
feed_value = self._feed_dict[key]
break
@@ -565,7 +560,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
list(self._tensor_filters.keys()))
if self._feed_dict:
# Register tab completion for feed_dict keys.
- feed_keys = [(key if isinstance(key, six.string_types) else key.name)
+ feed_keys = [cli_shared.get_graph_element_name(key)
for key in self._feed_dict.keys()]
curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 8a2fe7283c..770a496aa9 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -25,6 +25,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -100,6 +101,9 @@ class LocalCLIDebuggerWrapperSessionForTest(
else:
self.observers["run_end_cli_run_numbers"].append(self._run_call_count)
+ readline_cli = ui_factory.get_ui("readline")
+ self._register_this_run_info(readline_cli)
+
while True:
command = self._command_sequence[self._command_pointer]
self._command_pointer += 1