aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-11-11 13:08:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-11 13:25:03 -0800
commita8967c15a45be5517dec8c2c343f84e36b001b7b (patch)
tree90b54bde1c07a48f441bbf7b8fa4f922cf34cb1d
parent20df37f40296662519b89fd6658e43fce7c000b7 (diff)
tfdbg: example for debugging tf-learn Estimator.fit() & minor changes
$ blaze build -c opt third_party/tensorflow/python/debug:debug_tflearn_iris && blaze-bin/third_party/tensorflow/python/debug/debug_tflearn_iris --debug --train_steps=1 Minor changes made in this CL: * Fix a bug in local_cli_wrapper related to computing fetch and feed summary strings. * Minor tweaks in debug/examples/README.md * Curses CLI: Add scroll direction information for UI clarity Change: 138910196
-rw-r--r--tensorflow/python/debug/BUILD18
-rw-r--r--tensorflow/python/debug/__init__.py1
-rw-r--r--tensorflow/python/debug/cli/curses_ui.py13
-rw-r--r--tensorflow/python/debug/cli/curses_ui_test.py45
-rw-r--r--tensorflow/python/debug/examples/README.md46
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py115
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py91
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py12
8 files changed, 309 insertions, 32 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 0410b30db2..9d44840c88 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -22,6 +22,7 @@ py_library(
deps = [
":debug_data",
":debug_utils",
+ ":hooks",
":local_cli_wrapper",
],
)
@@ -116,6 +117,16 @@ py_library(
],
)
+py_library(
+ name = "hooks",
+ srcs = ["wrappers/hooks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":local_cli_wrapper",
+ "//tensorflow/python:session",
+ ],
+)
+
py_binary(
name = "debug_fibonacci",
srcs = ["examples/debug_fibonacci.py"],
@@ -140,6 +151,13 @@ py_binary(
],
)
+py_binary(
+ name = "debug_tflearn_iris",
+ srcs = ["examples/debug_tflearn_iris.py"],
+ srcs_version = "PY2AND3",
+ deps = ["//tensorflow:tensorflow_py"],
+)
+
py_test(
name = "debug_data_test",
size = "small",
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index fd7863e6d3..812c31ed16 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -28,4 +28,5 @@ from tensorflow.python.debug.debug_utils import add_debug_tensor_watch
from tensorflow.python.debug.debug_utils import watch_graph
from tensorflow.python.debug.debug_utils import watch_graph_with_blacklists
+from tensorflow.python.debug.wrappers.hooks import LocalCLIDebugHook
from tensorflow.python.debug.wrappers.local_cli_wrapper import LocalCLIDebugWrapperSession
diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py
index 8e3d069aa0..4dacc6c3c5 100644
--- a/tensorflow/python/debug/cli/curses_ui.py
+++ b/tensorflow/python/debug/cli/curses_ui.py
@@ -937,10 +937,19 @@ class CursesUI(object):
if self._output_pad_height > self._output_pad_screen_height + 1:
# Display information about the scrolling of tall screen output.
- self._scroll_info = "--- Scroll: %.2f%% " % (100.0 * (min(
+ scroll_percentage = 100.0 * (min(
1.0,
float(self._output_pad_row) /
- (self._output_pad_height - self._output_pad_screen_height - 1))))
+ (self._output_pad_height - self._output_pad_screen_height - 1)))
+ if self._output_pad_row == 0:
+ scroll_directions = " (PgDn)"
+ elif self._output_pad_row >= (
+ self._output_pad_height - self._output_pad_screen_height - 1):
+ scroll_directions = " (PgUp)"
+ else:
+ scroll_directions = " (PgDn/PgUp)"
+ self._scroll_info = "--- Scroll%s: %.2f%% " % (scroll_directions,
+ scroll_percentage)
self._output_array_pointer_indices = self._show_array_indices()
diff --git a/tensorflow/python/debug/cli/curses_ui_test.py b/tensorflow/python/debug/cli/curses_ui_test.py
index 94a80cfac6..8c4b150827 100644
--- a/tensorflow/python/debug/cli/curses_ui_test.py
+++ b/tensorflow/python/debug/cli/curses_ui_test.py
@@ -320,19 +320,19 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
# Initial scroll: At the top.
- self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
+ self.assertIn("Scroll (PgDn): 0.00%", ui.scroll_messages[0])
# After 1st scrolling (PageDown).
# The screen output shouldn't have changed. Only the viewport should.
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
- self.assertIn("Scroll: 1.69%", ui.scroll_messages[1])
+ self.assertIn("Scroll (PgDn/PgUp): 1.69%", ui.scroll_messages[1])
# After 2nd scrolling (PageDown).
- self.assertIn("Scroll: 3.39%", ui.scroll_messages[2])
+ self.assertIn("Scroll (PgDn/PgUp): 3.39%", ui.scroll_messages[2])
# After 3rd scrolling (PageUp).
- self.assertIn("Scroll: 1.69%", ui.scroll_messages[3])
+ self.assertIn("Scroll (PgDn/PgUp): 1.69%", ui.scroll_messages[3])
def testCutOffTooManyOutputLines(self):
ui = MockCursesUI(
@@ -375,16 +375,16 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
# Initial scroll: At the top.
- self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
+ self.assertIn("Scroll (PgDn): 0.00%", ui.scroll_messages[0])
# After 1st scrolling (End).
- self.assertIn("Scroll: 100.00%", ui.scroll_messages[1])
+ self.assertIn("Scroll (PgUp): 100.00%", ui.scroll_messages[1])
# After 2nd scrolling (End).
- self.assertIn("Scroll: 100.00%", ui.scroll_messages[2])
+ self.assertIn("Scroll (PgUp): 100.00%", ui.scroll_messages[2])
# After 3rd scrolling (Hhome).
- self.assertIn("Scroll: 0.00%", ui.scroll_messages[3])
+ self.assertIn("Scroll (PgDn): 0.00%", ui.scroll_messages[3])
def testRunUIWithInitCmd(self):
"""Run UI with an initial command specified."""
@@ -398,7 +398,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
- self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
+ self.assertIn("Scroll (PgDn): 0.00%", ui.scroll_messages[0])
def testCompileHelpWithoutHelpIntro(self):
ui = MockCursesUI(
@@ -571,7 +571,7 @@ class CursesTest(test_util.TensorFlowTestCase):
# The 1st scroll info should contain scrolling, because the screen size
# is less than the number of lines in the output.
- self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
+ self.assertIn("Scroll (PgDn): 0.00%", ui.scroll_messages[0])
def testTabCompletionWithCommonPrefix(self):
# Type "b" and trigger tab completion.
@@ -973,77 +973,80 @@ class CursesTest(test_util.TensorFlowTestCase):
0: None,
-1: [1, 0]
}, ui.output_array_pointer_indices[0])
- self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[0])
+ self.assertIn(" Scroll (PgDn): 0.00% -[1,0] ", ui.scroll_messages[0])
# Scrolled down one line.
self.assertEqual({
0: None,
-1: [2, 0]
}, ui.output_array_pointer_indices[1])
- self.assertIn(" Scroll: 16.67% -[2,0] ", ui.scroll_messages[1])
+ self.assertIn(" Scroll (PgDn/PgUp): 16.67% -[2,0] ", ui.scroll_messages[1])
# Scrolled down one line.
self.assertEqual({
0: [0, 0],
-1: [3, 0]
}, ui.output_array_pointer_indices[2])
- self.assertIn(" Scroll: 33.33% [0,0]-[3,0] ", ui.scroll_messages[2])
+ self.assertIn(" Scroll (PgDn/PgUp): 33.33% [0,0]-[3,0] ",
+ ui.scroll_messages[2])
# Scrolled down one line.
self.assertEqual({
0: [1, 0],
-1: [4, 0]
}, ui.output_array_pointer_indices[3])
- self.assertIn(" Scroll: 50.00% [1,0]-[4,0] ", ui.scroll_messages[3])
+ self.assertIn(" Scroll (PgDn/PgUp): 50.00% [1,0]-[4,0] ",
+ ui.scroll_messages[3])
# Scroll to the bottom.
self.assertEqual({
0: [4, 0],
-1: None
}, ui.output_array_pointer_indices[4])
- self.assertIn(" Scroll: 100.00% [4,0]- ", ui.scroll_messages[4])
+ self.assertIn(" Scroll (PgUp): 100.00% [4,0]- ", ui.scroll_messages[4])
# Attempt to scroll beyond the bottom should lead to no change.
self.assertEqual({
0: [4, 0],
-1: None
}, ui.output_array_pointer_indices[5])
- self.assertIn(" Scroll: 100.00% [4,0]- ", ui.scroll_messages[5])
+ self.assertIn(" Scroll (PgUp): 100.00% [4,0]- ", ui.scroll_messages[5])
# Scrolled up one line.
self.assertEqual({
0: [3, 0],
-1: None
}, ui.output_array_pointer_indices[6])
- self.assertIn(" Scroll: 83.33% [3,0]- ", ui.scroll_messages[6])
+ self.assertIn(" Scroll (PgDn/PgUp): 83.33% [3,0]- ", ui.scroll_messages[6])
# Scrolled up one line.
self.assertEqual({
0: [2, 0],
-1: None
}, ui.output_array_pointer_indices[7])
- self.assertIn(" Scroll: 66.67% [2,0]- ", ui.scroll_messages[7])
+ self.assertIn(" Scroll (PgDn/PgUp): 66.67% [2,0]- ", ui.scroll_messages[7])
# Scrolled up one line.
self.assertEqual({
0: [1, 0],
-1: [4, 0]
}, ui.output_array_pointer_indices[8])
- self.assertIn(" Scroll: 50.00% [1,0]-[4,0] ", ui.scroll_messages[8])
+ self.assertIn(" Scroll (PgDn/PgUp): 50.00% [1,0]-[4,0] ",
+ ui.scroll_messages[8])
# Scroll to the top.
self.assertEqual({
0: None,
-1: [1, 0]
}, ui.output_array_pointer_indices[9])
- self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[9])
+ self.assertIn(" Scroll (PgDn): 0.00% -[1,0] ", ui.scroll_messages[9])
# Attempt to scroll pass the top limit should lead to no change.
self.assertEqual({
0: None,
-1: [1, 0]
}, ui.output_array_pointer_indices[10])
- self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[10])
+ self.assertIn(" Scroll (PgDn): 0.00% -[1,0] ", ui.scroll_messages[10])
def testScrollTensorByValidIndices(self):
"""Test scrolling to specified (valid) indices in a tensor."""
diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md
index 8719dea3f7..b58c7516e9 100644
--- a/tensorflow/python/debug/examples/README.md
+++ b/tensorflow/python/debug/examples/README.md
@@ -2,11 +2,13 @@
**(Under development, subject to change)**
-This tutorial showcases how to use the TensorFlow Debugger (**tfdbg**) to debug
-a frequently encountered problem in TensorFlow model development: bad numerical
-values (`nan`s and `inf`s) causing training to fail.
+This tutorial showcases the features of TensorFlow Debugger (**tfdbg**)
+command-line interface.
+It contains an example of how to debug a frequently encountered problem in
+TensorFlow model development: bad numerical values (`nan`s and `inf`s) causing
+training to fail.
-To observe the issue, run the following code without the debugger:
+To **observe** such an issue, run the following code without the debugger:
```none
bazel build -c opt tensorflow/python/debug:debug_mnist && \
@@ -345,11 +347,43 @@ stuck. Success!
[tfprof](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/tfprof)
and other profiling tools for TensorFlow.
-**Q**: _How do I link tfdbg against my Session in Bazel?_
+**Q**: _How do I link tfdbg against my `Session` in Bazel?_
**A**: In your BUILD rule, declare the dependency: `"//tensorflow:tensorflow_py"`.
In your Python file, add:
- `from tensorflow.python import debug as tf_debug`
+
+```python
+from tensorflow.python import debug as tf_debug
+
+# Then wrap your TensorFlow Session with the local-CLI wrapper.
+sess = tf_debug.LocalCLIDebugWrapperSession(sess)
+```
+
+**Q**: _Can I use `tfdbg` if I am using tf-learn Estimators, instead of
+managing my own `Session` objects?_
+
+**A**: Currently, `tfdbg` can only debug the `fit()` method of tf-learn
+Estimators. Support for debugging `evaluate()` will come soon. To debug
+`Estimator.fit()`, create a monitor and supply it as an argument. For example:
+
+```python
+from tensorflow.python import debug as tf_debug
+
+# Create a local CLI debug hook and use it as a monitor when calling fit().
+classifier.fit(x=training_set.data,
+ y=training_set.target,
+ steps=1000,
+ monitors=[tf_debug.LocalCLIDebugHook()])
+```
+
+For a detailed [example](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py) based on
+[tf-learn's iris tutorial](../../../g3doc/tutorials/tflearn/index.md),
+run:
+
+```none
+bazel build -c opt tensorflow/python/debug:debug_tflearn_iris && \
+ bazel-bin/tensorflow/python/debug/debug_tflearn_iris
+```
**Q**: _Does tfdbg help debugging runtime errors such as shape mismatches?_
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
new file mode 100644
index 0000000000..4c104b3e0a
--- /dev/null
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -0,0 +1,115 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug the tf-learn iris example, based on the tf-learn tutorial."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+from tensorflow.python import debug as tf_debug
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/iris_data",
+ "Directory to save the training and test data in.")
+flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.")
+flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.")
+flags.DEFINE_boolean("debug", False,
+ "Use debugger to track down bad values during training")
+
+# URLs to download data sets from, if necessary.
+IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv"
+IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv"
+
+
+def maybe_download_data():
+ """Download data sets if necessary.
+
+ Returns:
+ Paths to the training and test data files.
+ """
+
+ if not os.path.isdir(FLAGS.data_dir):
+ os.makedirs(FLAGS.data_dir)
+
+ training_data_path = os.path.join(FLAGS.data_dir,
+ os.path.basename(IRIS_TRAINING_DATA_URL))
+ if not os.path.isfile(training_data_path):
+ train_file = open(training_data_path, "wt")
+ urllib.request.urlretrieve(IRIS_TRAINING_DATA_URL, train_file.name)
+ train_file.close()
+
+ print("Training data are downloaded to %s" % train_file.name)
+
+ test_data_path = os.path.join(FLAGS.data_dir,
+ os.path.basename(IRIS_TEST_DATA_URL))
+ if not os.path.isfile(test_data_path):
+ test_file = open(test_data_path, "wt")
+ urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name)
+ test_file.close()
+
+ print("Test data are downloaded to %s" % test_file.name)
+
+ return training_data_path, test_data_path
+
+
+def main(_):
+ training_data_path, test_data_path = maybe_download_data()
+
+ # Load datasets.
+ training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
+ filename=training_data_path,
+ target_dtype=np.int,
+ features_dtype=np.float32)
+ test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
+ filename=test_data_path, target_dtype=np.int, features_dtype=np.float32)
+
+ # Specify that all features have real-value data
+ feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
+
+ # Build 3 layer DNN with 10, 20, 10 units respectively.
+ model_dir = FLAGS.model_dir or tempfile.mkdtemp(prefix="debug_tflearn_iris_")
+
+ classifier = tf.contrib.learn.DNNClassifier(
+ feature_columns=feature_columns,
+ hidden_units=[10, 20, 10],
+ n_classes=3,
+ model_dir=model_dir)
+
+ monitors = [tf_debug.LocalCLIDebugHook()] if FLAGS.debug else None
+
+ # Fit model.
+ classifier.fit(x=training_set.data,
+ y=training_set.target,
+ steps=FLAGS.train_steps,
+ monitors=monitors)
+
+ # Evaluate accuracy.
+ accuracy_score = classifier.evaluate(
+ x=test_set.data, y=test_set.target)["accuracy"]
+ # TODO(cais): Add debug monitor for evaluate()
+
+ print("After training %d steps, Accuracy = %f" %
+ (FLAGS.train_steps, accuracy_score))
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
new file mode 100644
index 0000000000..14123d95f8
--- /dev/null
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -0,0 +1,91 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tfdbg CLI as SessionRunHook."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.debug import debug_utils
+from tensorflow.python.debug.wrappers import framework
+from tensorflow.python.debug.wrappers import local_cli_wrapper
+from tensorflow.python.training import session_run_hook
+
+
+class LocalCLIDebugHook(session_run_hook.SessionRunHook,
+ local_cli_wrapper.LocalCLIDebugWrapperSession):
+ """Command-line-interface debugger hook.
+
+ Can be used as a monitor/hook for tf.train.MonitoredSession.
+ """
+
+ def __init__(self):
+ """Create a local debugger command-line interface (CLI) hook."""
+
+ self._wrapper_initialized = False
+
+ def begin(self):
+ pass
+
+ def before_run(self, run_context):
+ if not self._wrapper_initialized:
+ local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
+ self, run_context.session)
+ self._wrapper_initialized = True
+
+ # Increment run call counter.
+ self._run_call_count += 1
+
+ # Adapt run_context to an instance of OnRunStartRequest for invoking
+ # superclass on_run_start().
+ on_run_start_request = framework.OnRunStartRequest(
+ run_context.original_args.fetches, run_context.original_args.feed_dict,
+ None, None, self._run_call_count)
+
+ on_run_start_response = self.on_run_start(on_run_start_request)
+ self._performed_action = on_run_start_response.action
+
+ run_args = session_run_hook.SessionRunArgs(
+ None, feed_dict=None, options=config_pb2.RunOptions())
+ if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
+ self._decorate_options_for_debug(run_args.options,
+ run_context.session.graph)
+ elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
+ raise NotImplementedError(
+ "OnRunStartAction INVOKE_STEPPER has not been implemented.")
+
+ return run_args
+
+ def after_run(self, run_context, run_values):
+ # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
+ # on_run_end()
+ if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
+ on_run_end_request = framework.OnRunEndRequest(self._performed_action,
+ run_values.run_metadata)
+
+ self.on_run_end(on_run_end_request)
+
+ def _decorate_options_for_debug(self, options, graph):
+ """Modify RunOptions.debug_tensor_watch_opts for debugging.
+
+ Args:
+ options: (config_pb2.RunOptions) The RunOptions instance to be modified.
+ graph: A TensorFlow Graph object.
+ """
+
+ debug_utils.watch_graph(
+ options, graph, debug_urls=self._get_run_debug_urls())
+ options.output_partition_graphs = True
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index a5974afde8..74e3236e7d 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -23,6 +23,8 @@ import shutil
import sys
import tempfile
+import six
+
# Google-internal import(s).
from tensorflow.python.debug import debug_data
from tensorflow.python.debug.cli import analyzer_cli
@@ -30,6 +32,7 @@ from tensorflow.python.debug.cli import curses_ui
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
_DUMP_ROOT_PREFIX = "tfdbg_"
@@ -180,7 +183,10 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
else:
feed_dict_lines = []
for feed_key in request.feed_dict:
- feed_dict_lines.append(feed_key.name)
+ if isinstance(feed_key, six.string_types):
+ feed_dict_lines.append(feed_key)
+ else:
+ feed_dict_lines.append(feed_key.name)
# TODO(cais): Refactor into its own function.
help_intro = [
@@ -457,11 +463,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._run_call_count = run_call_count
self._run_description = "run #%d: " % self._run_call_count
- if isinstance(fetches, ops.Tensor) or isinstance(fetches, ops.Operation):
+ if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
self._run_description += "fetch: %s; " % fetches.name
else:
# Could be list, tuple, dict or namedtuple.
- self._run_description += "%d fetch(es)" % len(fetches)
+ self._run_description += "%d fetch(es); " % len(fetches)
if not feed_dict:
self._run_description += "0 feeds"