aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-10-28 08:16:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-28 09:36:51 -0700
commita304f0a1e430be84fcbe4ffc6bd80e26fcf08393 (patch)
tree80f5be028d51a127d2c2f2c328c628420a3c3b07
parentb281537cd961ff08bf1ce451fd9ad0df642edbe4 (diff)
tfdbg: improvements and fixes to tensor display in CLI
1) Enable scrolling to next regex match with command "/" following "/regex". 2) Enable scrolling to tensor indices with command such as "@[1, 2]" and "@100,30,0". 3) Display tensor indices at the top and bottom of the screen, and in scroll status info bar. 4) Handle invalid regex search commands, e.g., "/[", without crashing. Doc updated accordingly. Change: 137518091
-rw-r--r--tensorflow/python/debug/BUILD2
-rw-r--r--tensorflow/python/debug/cli/command_parser.py28
-rw-r--r--tensorflow/python/debug/cli/command_parser_test.py29
-rw-r--r--tensorflow/python/debug/cli/curses_ui.py356
-rw-r--r--tensorflow/python/debug/cli/curses_ui_test.py394
-rw-r--r--tensorflow/python/debug/cli/debugger_cli_common.py18
-rw-r--r--tensorflow/python/debug/cli/debugger_cli_common_test.py25
-rw-r--r--tensorflow/python/debug/examples/README.md15
8 files changed, 764 insertions, 103 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 4015d6850b..580bd0e79b 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -89,6 +89,7 @@ py_library(
deps = [
":command_parser",
":debugger_cli_common",
+ ":tensor_format",
],
)
@@ -180,6 +181,7 @@ py_test(
deps = [
":curses_ui",
":debugger_cli_common",
+ ":tensor_format",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
diff --git a/tensorflow/python/debug/cli/command_parser.py b/tensorflow/python/debug/cli/command_parser.py
index 4a70468e27..4f940d480c 100644
--- a/tensorflow/python/debug/cli/command_parser.py
+++ b/tensorflow/python/debug/cli/command_parser.py
@@ -77,11 +77,11 @@ def parse_tensor_name_with_slicing(in_str):
Args:
in_str: (str) Input name of the tensor, potentially followed by a slicing
string. E.g.: Without slicing string: "hidden/weights/Variable:0", with
- slicing string: "hidden/weights/Varaible:0[1, :]"
+ slicing string: "hidden/weights/Variable:0[1, :]"
Returns:
(str) name of the tensor
- (str) sliciing string, if any. If no slicing string is present, return "".
+ (str) slicing string, if any. If no slicing string is present, return "".
"""
if in_str.count("[") == 1 and in_str.endswith("]"):
@@ -108,3 +108,27 @@ def validate_slicing_string(slicing_string):
"""
return bool(re.search(r"^\[(\d|,|\s|:)+\]$", slicing_string))
+
+
+def parse_indices(indices_string):
+ """Parse a string representing indices.
+
+ For example, if the input is "[1, 2, 3]", the return value will be a list of
+ indices: [1, 2, 3]
+
+ Args:
+ indices_string: (str) a string representing indices. Can optionally be
+ surrounded by a pair of brackets.
+
+ Returns:
+ (list of int): Parsed indices.
+ """
+
+ # Strip whitespace.
+ indices_string = re.sub(r"\s+", "", indices_string)
+
+ # Strip any brackets at the two ends.
+ if indices_string.startswith("[") and indices_string.endswith("]"):
+ indices_string = indices_string[1:-1]
+
+ return [int(element) for element in indices_string.split(",")]
diff --git a/tensorflow/python/debug/cli/command_parser_test.py b/tensorflow/python/debug/cli/command_parser_test.py
index b819f25e69..acb1d7f90b 100644
--- a/tensorflow/python/debug/cli/command_parser_test.py
+++ b/tensorflow/python/debug/cli/command_parser_test.py
@@ -129,5 +129,34 @@ class ValidateSlicingStringTest(test_util.TensorFlowTestCase):
self.assertFalse(command_parser.validate_slicing_string("[5, bar]"))
+class ParseIndicesTest(test_util.TensorFlowTestCase):
+
+ def testParseValidIndicesStringsWithBrackets(self):
+ self.assertEqual([0], command_parser.parse_indices("[0]"))
+ self.assertEqual([0], command_parser.parse_indices(" [0] "))
+ self.assertEqual([-1, 2], command_parser.parse_indices("[-1, 2]"))
+ self.assertEqual([3, 4, -5],
+ command_parser.parse_indices("[3,4,-5]"))
+
+ def testParseValidIndicesStringsWithoutBrackets(self):
+ self.assertEqual([0], command_parser.parse_indices("0"))
+ self.assertEqual([0], command_parser.parse_indices(" 0 "))
+ self.assertEqual([-1, 2], command_parser.parse_indices("-1, 2"))
+ self.assertEqual([3, 4, -5], command_parser.parse_indices("3,4,-5"))
+
+ def testParseInvalidIndicesStringsWithoutBrackets(self):
+ with self.assertRaisesRegexp(
+ ValueError, r"invalid literal for int\(\) with base 10: 'a'"):
+ self.assertEqual([0], command_parser.parse_indices("0,a"))
+
+ with self.assertRaisesRegexp(
+ ValueError, r"invalid literal for int\(\) with base 10: '2\]'"):
+ self.assertEqual([0], command_parser.parse_indices("1, 2]"))
+
+ with self.assertRaisesRegexp(
+ ValueError, r"invalid literal for int\(\) with base 10: ''"):
+ self.assertEqual([0], command_parser.parse_indices("3, 4,"))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py
index bcdd675f9b..df497f2888 100644
--- a/tensorflow/python/debug/cli/curses_ui.py
+++ b/tensorflow/python/debug/cli/curses_ui.py
@@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import tensor_format
class CursesUI(object):
@@ -41,6 +42,8 @@ class CursesUI(object):
CLI_TERMINATOR_KEY = 7 # Terminator key for input text box.
CLI_TAB_KEY = ord("\t")
REGEX_SEARCH_PREFIX = "/"
+ TENSOR_INDICES_NAVIGATION_PREFIX = "@"
+ ERROR_MESSAGE_PREFIX = "ERROR: "
# Possible Enter keys. 343 is curses key code for the num-pad Enter key when
# num lock is off.
@@ -51,6 +54,27 @@ class CursesUI(object):
_SCROLL_DOWN = "down"
_SCROLL_HOME = "home"
_SCROLL_END = "end"
+ _SCROLL_TO_LINE_INDEX = "scroll_to_line_index"
+
+ _FOREGROUND_COLORS = {
+ "white": curses.COLOR_WHITE,
+ "red": curses.COLOR_RED,
+ "green": curses.COLOR_GREEN,
+ "yellow": curses.COLOR_YELLOW,
+ "blue": curses.COLOR_BLUE,
+ "magenta": curses.COLOR_MAGENTA,
+ "black": curses.COLOR_BLACK,
+ }
+ _BACKGROUND_COLORS = {
+ "white": curses.COLOR_WHITE,
+ "black": curses.COLOR_BLACK,
+ }
+
+ # Font attribute for search and highlighting.
+ _SEARCH_HIGHLIGHT_FONT_ATTR = "black_on_white"
+ _ARRAY_INDICES_COLOR_PAIR = "black_on_white"
+ _ERROR_TOAST_COLOR_PAIR = "red_on_white"
+ _STATUS_BAR_COLOR_PAIR = "black_on_white"
def __init__(self):
self._screen_init()
@@ -94,13 +118,11 @@ class CursesUI(object):
# State related to screen output.
self._output_pad = None
+ self._output_pad_row = 0
+ self._output_array_pointer_indices = None
self._curr_unwrapped_output = None
self._curr_wrapped_output = None
- # NamedTuple for rectangular locations on screen
- self.rectangle = collections.namedtuple("rectangle",
- "top left bottom right")
-
# Register signal handler for SIGINT.
signal.signal(signal.SIGINT, self._interrupt_handler)
@@ -111,6 +133,10 @@ class CursesUI(object):
and output region according to the terminal size.
"""
+ # NamedTuple for rectangular locations on screen
+ self.rectangle = collections.namedtuple("rectangle",
+ "top left bottom right")
+
# Height of command text box
self._command_textbox_height = 2
@@ -138,11 +164,22 @@ class CursesUI(object):
# Maximum number of lines the candidates display can have.
self._candidates_max_lines = int(self._output_num_rows / 2)
- # Font attribute for search and highlighting.
- self._search_highlight_font_attr = "bw_reversed"
-
self.max_output_lines = 10000
+ # Regex search state.
+ self._curr_search_regex = None
+ self._regex_match_lines = None
+
+ # Size of view port on screen, which is always smaller or equal to the
+ # screen size.
+ self._output_pad_screen_height = self._output_num_rows - 1
+ self._output_pad_screen_width = self._max_x - 1
+ self._output_pad_screen_location = self.rectangle(
+ top=self._output_top_row,
+ left=0,
+ bottom=self._output_top_row + self._output_num_rows,
+ right=self._output_pad_screen_width)
+
def _screen_init(self):
"""Screen initialization.
@@ -155,24 +192,21 @@ class CursesUI(object):
# Prepare color pairs.
curses.start_color()
- curses.init_pair(1, curses.COLOR_WHITE, curses.COLOR_BLACK)
- curses.init_pair(2, curses.COLOR_RED, curses.COLOR_BLACK)
- curses.init_pair(3, curses.COLOR_GREEN, curses.COLOR_BLACK)
- curses.init_pair(4, curses.COLOR_YELLOW, curses.COLOR_BLACK)
- curses.init_pair(5, curses.COLOR_BLUE, curses.COLOR_BLACK)
- curses.init_pair(6, curses.COLOR_MAGENTA, curses.COLOR_BLACK)
- curses.init_pair(7, curses.COLOR_BLACK, curses.COLOR_WHITE)
-
self._color_pairs = {}
- self._color_pairs["white"] = curses.color_pair(1)
- self._color_pairs["red"] = curses.color_pair(2)
- self._color_pairs["green"] = curses.color_pair(3)
- self._color_pairs["yellow"] = curses.color_pair(4)
- self._color_pairs["blue"] = curses.color_pair(5)
- self._color_pairs["magenta"] = curses.color_pair(6)
+ color_index = 0
- # Black-white reversed
- self._color_pairs["bw_reversed"] = curses.color_pair(7)
+ for fg_color in self._FOREGROUND_COLORS:
+ for bg_color in self._BACKGROUND_COLORS:
+
+ color_index += 1
+ curses.init_pair(color_index, self._FOREGROUND_COLORS[fg_color],
+ self._BACKGROUND_COLORS[bg_color])
+
+ color_name = fg_color
+ if bg_color != "black":
+ color_name += "_on_" + bg_color
+
+ self._color_pairs[color_name] = curses.color_pair(color_index)
# A_BOLD is not really a "color". But place it here for convenience.
self._color_pairs["bold"] = curses.A_BOLD
@@ -391,18 +425,41 @@ class CursesUI(object):
if command:
self._command_history_store.add_command(command)
- if (len(command) > len(self.REGEX_SEARCH_PREFIX) and
- command.startswith(self.REGEX_SEARCH_PREFIX) and
+ if (command.startswith(self.REGEX_SEARCH_PREFIX) and
self._curr_unwrapped_output):
- # Regex search and highlighting in screen output.
- regex = command[len(self.REGEX_SEARCH_PREFIX):]
+ if len(command) > len(self.REGEX_SEARCH_PREFIX):
+ # Command is like "/regex". Perform regex search.
+ regex = command[len(self.REGEX_SEARCH_PREFIX):]
+
+ self._curr_search_regex = regex
+ self._display_output(self._curr_unwrapped_output, highlight_regex=regex)
+ elif self._regex_match_lines:
+ # Command is "/". Continue scrolling down matching lines.
+ self._display_output(
+ self._curr_unwrapped_output,
+ is_refresh=True,
+ highlight_regex=self._curr_search_regex)
- # TODO(cais): Support scrolling to matches.
- # TODO(cais): Display warning message on screen if no match.
- self._display_output(self._curr_unwrapped_output, highlight_regex=regex)
self._command_pointer = 0
self._pending_command = ""
return
+ elif command.startswith(self.TENSOR_INDICES_NAVIGATION_PREFIX):
+ indices_str = command[1:].strip()
+ if indices_str:
+ try:
+ indices = command_parser.parse_indices(indices_str)
+ omitted, line_index = tensor_format.locate_tensor_element(
+ self._curr_wrapped_output, indices)
+
+ if not omitted:
+ self._scroll_output(
+ self._SCROLL_TO_LINE_INDEX, line_index=line_index)
+ except Exception as e: # pylint: disable=broad-except
+ self._error_toast(str(e))
+ else:
+ self._error_toast("Empty indices.")
+
+ return
prefix, args = self._parse_command(command)
@@ -420,7 +477,7 @@ class CursesUI(object):
exit_token = e.exit_token
else:
screen_output = debugger_cli_common.RichTextLines([
- "ERROR: Invalid command prefix \"%s\"" % prefix
+ self.ERROR_MESSAGE_PREFIX + "Invalid command prefix \"%s\"" % prefix
])
# Clear active command history. Until next up/down history navigation
@@ -613,25 +670,28 @@ class CursesUI(object):
return curses.newpad(rows, cols)
- def _display_output(self, output, is_refresh=False, highlight_regex=None):
- """Display text output in a scrollable text pad.
+ def _screen_display_output(self, output):
+ """Actually render text output on the screen.
+
+ Wraps the lines according to screen width. Pad lines below according to
+ screen height so that the user can scroll the output to a state where
+ the last non-empty line is on the top of the screen. Then renders the
+ lines on the screen.
Args:
- output: A RichTextLines object that is the screen output text.
- is_refresh: (bool) Is this a refreshing display with existing output.
- highlight_regex: (str) Optional string representing the regex used to
- search and highlight in the current screen output.
+ output: (RichTextLines) text lines to display on the screen. These lines
+ may have widths exceeding the screen width. This method will take care
+ of the wrapping.
"""
- if highlight_regex:
- output = debugger_cli_common.regex_find(
- output, highlight_regex, font_attr=self._search_highlight_font_attr)
- else:
- self._curr_unwrapped_output = output
-
+ # Wrap the output lines according to screen width.
self._curr_wrapped_output = debugger_cli_common.wrap_rich_text_lines(
output, self._max_x - 1)
+ # Append lines to curr_wrapped_output so that the user can scroll to a
+ # state where the last text line is on the top of the output area.
+ self._curr_wrapped_output.lines.extend([""] * (self._output_num_rows - 1))
+
# Limit number of lines displayed to avoid curses overflow problems.
if self._curr_wrapped_output.num_lines() > self.max_output_lines:
self._curr_wrapped_output = self._curr_wrapped_output.slice(
@@ -646,17 +706,64 @@ class CursesUI(object):
self._output_pad_width) = self._display_lines(self._curr_wrapped_output,
self._output_num_rows)
- # Size of view port on screen, which is always smaller or equal to the
- # screen size.
- self._output_pad_screen_height = self._output_num_rows - 1
- self._output_pad_screen_width = self._max_x - 1
- self._output_pad_screen_location = self.rectangle(
- top=self._output_top_row,
- left=0,
- bottom=self._output_top_row + self._output_num_rows,
- right=self._output_pad_screen_width)
+ def _display_output(self, output, is_refresh=False, highlight_regex=None):
+ """Display text output in a scrollable text pad.
+
+ This method does some preprocessing on the text lines, render them on the
+ screen and scroll to the appropriate line. These are done according to regex
+ highlighting requests (if any), scroll-to-next-match requests (if any),
+ and screen refrexh requests (if any).
+
+ TODO(cais): Separate these unrelated request to increase clarity and
+ maintainability.
+
+ Args:
+ output: A RichTextLines object that is the screen output text.
+ is_refresh: (bool) Is this a refreshing display with existing output.
+ highlight_regex: (str) Optional string representing the regex used to
+ search and highlight in the current screen output.
+ """
+
+ if highlight_regex:
+ try:
+ output = debugger_cli_common.regex_find(
+ output, highlight_regex, font_attr=self._SEARCH_HIGHLIGHT_FONT_ATTR)
+ except ValueError as e:
+ self._error_toast(str(e))
+ return
+
+ if not is_refresh:
+ # Perform new regex search on the current output.
+ self._regex_match_lines = output.annotations[
+ debugger_cli_common.REGEX_MATCH_LINES_KEY]
+ else:
+ # Continue scrolling down.
+ self._output_pad_row += 1
+ else:
+ self._curr_unwrapped_output = output
+
+ # Display output on the screen.
+ self._screen_display_output(output)
+
+ # Now that the text lines are displayed on the screen scroll to the
+ # appropriate line according to previous scrolling state and regex search
+ # and highlighting state.
- if is_refresh:
+ if highlight_regex:
+ next_match_line = -1
+ for match_line in self._regex_match_lines:
+ if match_line >= self._output_pad_row:
+ next_match_line = match_line
+ break
+
+ if next_match_line >= 0:
+ self._scroll_output(
+ self._SCROLL_TO_LINE_INDEX, line_index=next_match_line)
+ else:
+ # Regex search found no match >= current line number. Display message
+ # stating as such.
+ self._toast("Pattern not found", color=self._ERROR_TOAST_COLOR_PAIR)
+ elif is_refresh:
self._scroll_output(self._SCROLL_REFRESH)
else:
self._output_pad_row = 0
@@ -764,15 +871,19 @@ class CursesUI(object):
screen_location_left, screen_location_bottom,
screen_location_right)
- def _scroll_output(self, direction):
+ def _scroll_output(self, direction, line_index=None):
"""Scroll the output pad.
Args:
direction: _SCROLL_REFRESH, _SCROLL_UP, _SCROLL_DOWN, _SCROLL_HOME or
- _SCROLL_END
+ _SCROLL_END, _SCROLL_TO_LINE_INDEX
+ line_index: (int) Specifies the zero-based line index to scroll to.
+ Applicable only if direction is _SCROLL_TO_LINE_INDEX.
Raises:
ValueError: On invalid scroll direction.
+ TypeError: If line_index is not int and direction is
+ _SCROLL_TO_LINE_INDEX.
"""
if not self._output_pad:
@@ -797,6 +908,11 @@ class CursesUI(object):
# Scroll to bottom
self._output_pad_row = (
self._output_pad_height - self._output_pad_screen_height - 1)
+ elif direction == self._SCROLL_TO_LINE_INDEX:
+ if not isinstance(line_index, int):
+ raise TypeError("Invalid line_index type (%s) under mode %s" %
+ (type(line_index), self._SCROLL_TO_LINE_INDEX))
+ self._output_pad_row = line_index
else:
raise ValueError("Unsupported scroll mode: %s" % direction)
@@ -809,18 +925,103 @@ class CursesUI(object):
if self._output_pad_height > self._output_pad_screen_height + 1:
# Display information about the scrolling of tall screen output.
- self._scroll_info = "--- Scroll: %.2f%% " % (100.0 * (
+ self._scroll_info = "--- Scroll: %.2f%% " % (100.0 * (min(
+ 1.0,
float(self._output_pad_row) /
- (self._output_pad_height - self._output_pad_screen_height - 1)))
+ (self._output_pad_height - self._output_pad_screen_height - 1))))
+
+ self._output_array_pointer_indices = self._show_array_indices()
+
+ # Add array indices information to scroll message.
+ if self._output_array_pointer_indices:
+ if self._output_array_pointer_indices[0]:
+ self._scroll_info += self._format_indices(
+ self._output_array_pointer_indices[0])
+ self._scroll_info += "-"
+ if self._output_array_pointer_indices[-1]:
+ self._scroll_info += self._format_indices(
+ self._output_array_pointer_indices[-1])
+ self._scroll_info += " "
+
if len(self._scroll_info) < self._max_x:
self._scroll_info += "-" * (self._max_x - len(self._scroll_info))
self._screen_draw_text_line(
- self._output_scroll_row, self._scroll_info, color="green")
+ self._output_scroll_row,
+ self._scroll_info,
+ color=self._STATUS_BAR_COLOR_PAIR)
else:
# Screen output is not tall enough to cause scrolling.
self._scroll_info = "-" * self._max_x
self._screen_draw_text_line(
- self._output_scroll_row, self._scroll_info, color="green")
+ self._output_scroll_row,
+ self._scroll_info,
+ color=self._STATUS_BAR_COLOR_PAIR)
+
+ def _format_indices(self, indices):
+ # Remove the spaces to make it compact.
+ return repr(indices).replace(" ", "")
+
+ def _show_array_indices(self):
+ """Show array indices for the lines at the top and bottom of the output.
+
+ For the top line and bottom line of the output display area, show the
+ element indices of the array being displayed.
+
+ Returns:
+ If either the top of the bottom row has any matching array indices,
+ a dict from line index (0 being the top of the display area, -1
+ being the bottom of the display area) to array element indices. For
+ example:
+ {0: [0, 0], -1: [10, 0]}
+ Otherwise, None.
+ """
+
+ indices_top = self._show_array_index_at_line(0)
+
+ bottom_line_index = (self._output_pad_screen_location.bottom -
+ self._output_pad_screen_location.top - 1)
+ indices_bottom = self._show_array_index_at_line(bottom_line_index)
+
+ if indices_top or indices_bottom:
+ return {0: indices_top, -1: indices_bottom}
+ else:
+ return None
+
+ def _show_array_index_at_line(self, line_index):
+ """Show array indices for the specified line in the display area.
+
+ Uses the line number to array indices map in the annotations field of the
+ RichTextLines object being displayed.
+ If the displayed RichTextLines object does not contain such a mapping,
+ will do nothing.
+
+ Args:
+ line_index: (int) 0-based line index from the top of the display area.
+ For example,if line_index == 0, this method will display the array
+ indices for the line currently at the top of the display area.
+
+ Returns:
+ (list) The array indices at the specified line, if available. None, if
+ not available.
+ """
+
+ # Examine whether the index information is available for the specified line
+ # number.
+ pointer = self._output_pad_row + line_index
+ if pointer in self._curr_wrapped_output.annotations:
+ indices = self._curr_wrapped_output.annotations[pointer]["i0"]
+
+ array_indices_str = self._format_indices(indices)
+ array_indices_info = "@" + array_indices_str
+
+ self._toast(
+ array_indices_info,
+ color=self._ARRAY_INDICES_COLOR_PAIR,
+ line_index=self._output_pad_screen_location.top + line_index)
+
+ return indices
+ else:
+ return None
def _tab_complete(self, command_str):
"""Perform tab completion.
@@ -899,7 +1100,7 @@ class CursesUI(object):
})
candidates_output = debugger_cli_common.wrap_rich_text_lines(
- candidates_output, self._max_x - 1)
+ candidates_output, self._max_x - 2)
# Calculate how many lines the candidate text should occupy. Limit it to
# a maximum value.
@@ -914,6 +1115,41 @@ class CursesUI(object):
pad, 0, 0, self._candidates_top_row, 0,
self._candidates_top_row + candidates_num_rows - 1, self._max_x - 1)
+ def _toast(self, message, color=None, line_index=None):
+ """Display a one-line message on the screen.
+
+ By default, the toast is displayed in the line right above the scroll bar.
+ But the line location can be overridden with the line_index arg.
+
+ Args:
+ message: (str) the message to display.
+ color: (str) optional color attribute for the message.
+ line_index: (int) line index.
+ """
+
+ pad, _, _ = self._display_lines(
+ debugger_cli_common.RichTextLines(
+ message,
+ font_attr_segs={0: [(0, len(message), color or "white")]}),
+ 0)
+
+ right_end = min(len(message), self._max_x - 1)
+
+ if line_index is None:
+ line_index = self._output_scroll_row - 1
+ self._screen_scroll_output_pad(pad, 0, 0, line_index, 0, line_index,
+ right_end)
+
+ def _error_toast(self, message):
+ """Display a one-line error message on screen.
+
+ Args:
+ message: The error message, without the preceding "ERROR: " substring.
+ """
+
+ self._toast(
+ self.ERROR_MESSAGE_PREFIX + message, color=self._ERROR_TOAST_COLOR_PAIR)
+
def _interrupt_handler(self, signal_num, frame):
_ = signal_num # Unused.
_ = frame # Unused.
diff --git a/tensorflow/python/debug/cli/curses_ui_test.py b/tensorflow/python/debug/cli/curses_ui_test.py
index d1dba5a1fb..1905497870 100644
--- a/tensorflow/python/debug/cli/curses_ui_test.py
+++ b/tensorflow/python/debug/cli/curses_ui_test.py
@@ -20,8 +20,11 @@ from __future__ import print_function
import argparse
import curses
+import numpy as np
+
from tensorflow.python.debug.cli import curses_ui
from tensorflow.python.debug.cli import debugger_cli_common
+from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@@ -54,13 +57,19 @@ class MockCursesUI(curses_ui.CursesUI):
self.unwrapped_outputs = []
self.wrapped_outputs = []
self.scroll_messages = []
+ self.output_array_pointer_indices = []
+
+ self.output_pad_rows = []
# Observers of command textbox.
self.existing_commands = []
- # Observers for tab-completion candidates.
+ # Observer for tab-completion candidates.
self.candidates_lists = []
+ # Observer for toast messages.
+ self.toasts = []
+
curses_ui.CursesUI.__init__(self)
# Below, override the _screen_ prefixed member methods that interact with the
@@ -146,7 +155,7 @@ class MockCursesUI(curses_ui.CursesUI):
return codes_to_string(self._command_sequence[self._command_counter]
[:self._command_key_counter])
- def _scroll_output(self, direction):
+ def _scroll_output(self, direction, line_index=None):
"""Override to observe screen output.
This method is invoked after every command that generates a new screen
@@ -155,19 +164,28 @@ class MockCursesUI(curses_ui.CursesUI):
Args:
direction: which direction to scroll.
+ line_index: (int or None) Optional line index to scroll to. See doc string
+ of the overridden method for more information.
"""
- curses_ui.CursesUI._scroll_output(self, direction)
+ curses_ui.CursesUI._scroll_output(self, direction, line_index=line_index)
self.unwrapped_outputs.append(self._curr_unwrapped_output)
self.wrapped_outputs.append(self._curr_wrapped_output)
self.scroll_messages.append(self._scroll_info)
+ self.output_array_pointer_indices.append(self._output_array_pointer_indices)
+ self.output_pad_rows.append(self._output_pad_row)
def _display_candidates(self, candidates):
curses_ui.CursesUI._display_candidates(self, candidates)
self.candidates_lists.append(candidates)
+ def _toast(self, message, color=None, line_index=None):
+ curses_ui.CursesUI._toast(self, message, color=color, line_index=line_index)
+
+ self.toasts.append(message)
+
class CursesTest(test_util.TensorFlowTestCase):
@@ -188,6 +206,24 @@ class CursesTest(test_util.TensorFlowTestCase):
return debugger_cli_common.RichTextLines(["bar"] * parsed.num_times)
+ def _print_ones(self, args, screen_info=None):
+ ap = argparse.ArgumentParser(
+ description="Print all-one matrix.", usage=argparse.SUPPRESS)
+ ap.add_argument(
+ "-s",
+ "--size",
+ dest="size",
+ type=int,
+ default=3,
+ help="Size of the matrix. For example, of the value is 3, "
+ "the matrix will have shape (3, 3)")
+
+ parsed = ap.parse_args(args)
+
+ m = np.ones([parsed.size, parsed.size])
+
+ return tensor_format.format_tensor(m, "m")
+
def testInitialization(self):
ui = MockCursesUI(40, 80)
@@ -229,8 +265,9 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["ERROR: Invalid command prefix \"foo\""],
ui.unwrapped_outputs[0].lines)
+ # TODO(cais): Add explanation for the 35 extra lines.
self.assertEqual(["ERROR: Invalid command prefix \"foo\""],
- ui.wrapped_outputs[0].lines)
+ ui.wrapped_outputs[0].lines[:1])
# A single line of output should not have caused scrolling.
self.assertEqual("-" * 80, ui.scroll_messages[0])
@@ -273,7 +310,7 @@ class CursesTest(test_util.TensorFlowTestCase):
# Before scrolling.
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
# Initial scroll: At the top.
self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
@@ -281,14 +318,14 @@ class CursesTest(test_util.TensorFlowTestCase):
# After 1st scrolling (PageDown).
# The screen output shouldn't have changed. Only the viewport should.
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
- self.assertIn("Scroll: 4.17%", ui.scroll_messages[1])
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
+ self.assertIn("Scroll: 1.69%", ui.scroll_messages[1])
# After 2nd scrolling (PageDown).
- self.assertIn("Scroll: 8.33%", ui.scroll_messages[2])
+ self.assertIn("Scroll: 3.39%", ui.scroll_messages[2])
# After 3rd scrolling (PageUp).
- self.assertIn("Scroll: 4.17%", ui.scroll_messages[3])
+ self.assertIn("Scroll: 1.69%", ui.scroll_messages[3])
def testCutOffTooManyOutputLines(self):
ui = MockCursesUI(
@@ -304,7 +341,7 @@ class CursesTest(test_util.TensorFlowTestCase):
ui.run_ui()
self.assertEqual(["bar"] * 10 + ["Output cut off at 10 lines!"],
- ui.wrapped_outputs[0].lines)
+ ui.wrapped_outputs[0].lines[:11])
def testRunUIScrollTallOutputEndHome(self):
"""Scroll tall output with PageDown and PageUp."""
@@ -328,7 +365,7 @@ class CursesTest(test_util.TensorFlowTestCase):
# Before scrolling.
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
# Initial scroll: At the top.
self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
@@ -353,7 +390,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.unwrapped_outputs))
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
def testCompileHelpWithoutHelpIntro(self):
@@ -529,10 +566,6 @@ class CursesTest(test_util.TensorFlowTestCase):
# is less than the number of lines in the output.
self.assertIn("Scroll: 0.00%", ui.scroll_messages[0])
- # The 2nd scroll info should contain no scrolling, because the screen size
- # is now greater than the numberf lines in the output.
- self.assertEqual("-" * 85, ui.scroll_messages[1])
-
def testTabCompletionWithCommonPrefix(self):
# Type "b" and trigger tab completion.
ui = MockCursesUI(
@@ -556,7 +589,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.wrapped_outputs))
self.assertEqual(1, len(ui.scroll_messages))
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
def testTabCompletionEmptyTriggerWithoutCommonPrefix(self):
ui = MockCursesUI(
@@ -603,7 +636,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.wrapped_outputs))
self.assertEqual(1, len(ui.scroll_messages))
self.assertEqual(["bar"] * 60, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 60, ui.wrapped_outputs[0].lines[:60])
def testTabCompletionNoMatch(self):
ui = MockCursesUI(
@@ -625,7 +658,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["ERROR: Invalid command prefix \"c\""],
ui.unwrapped_outputs[0].lines)
self.assertEqual(["ERROR: Invalid command prefix \"c\""],
- ui.wrapped_outputs[0].lines)
+ ui.wrapped_outputs[0].lines[:1])
def testTabCompletionOneWordContext(self):
ui = MockCursesUI(
@@ -648,7 +681,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.wrapped_outputs))
self.assertEqual(1, len(ui.scroll_messages))
self.assertEqual(["bar"] * 30, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 30, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 30, ui.wrapped_outputs[0].lines[:30])
def testTabCompletionTwice(self):
ui = MockCursesUI(
@@ -674,7 +707,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.wrapped_outputs))
self.assertEqual(1, len(ui.scroll_messages))
self.assertEqual(["bar"] * 123, ui.unwrapped_outputs[0].lines)
- self.assertEqual(["bar"] * 123, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 123, ui.wrapped_outputs[0].lines[:123])
def testRegexSearch(self):
"""Test regex search."""
@@ -703,21 +736,128 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(3, len(ui.wrapped_outputs))
# The first output should have no highlighting.
- self.assertEqual(["bar"] * 3, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[0].lines[:3])
self.assertEqual({}, ui.wrapped_outputs[0].font_attr_segs)
# The second output should have highlighting for "b" and "r".
- self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines)
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines[:3])
for i in range(3):
- self.assertEqual([(0, 1, "bw_reversed"), (2, 3, "bw_reversed")],
+ self.assertEqual([(0, 1, "black_on_white"), (2, 3, "black_on_white")],
ui.wrapped_outputs[1].font_attr_segs[i])
# The third output should have highlighting for "a" only.
- self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines)
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines[:3])
for i in range(3):
- self.assertEqual([(1, 2, "bw_reversed")],
+ self.assertEqual([(1, 2, "black_on_white")],
ui.wrapped_outputs[2].font_attr_segs[i])
+ def testRegexSearchContinuation(self):
+ """Test continuing scrolling down to next regex match."""
+
+ ui = MockCursesUI(
+ 40,
+ 80,
+ command_sequence=[
+ string_to_codes("babble -n 3\n"),
+ string_to_codes("/(b|r)\n"), # Regex search and highlight.
+ string_to_codes("/\n"), # Continue scrolling down: 1st time.
+ string_to_codes("/\n"), # Continue scrolling down: 2nd time.
+ string_to_codes("/\n"), # Continue scrolling down: 3rd time.
+ string_to_codes("/\n"), # Continue scrolling down: 4th time.
+ self._EXIT
+ ])
+
+ ui.register_command_handler(
+ "babble", self._babble, "babble some", prefix_aliases=["b"])
+ ui.run_ui()
+
+ # The 1st output is for the non-searched output. The other three are for
+ # the searched output. Even though continuation search "/" is performed
+ # four times, there should be only three searched outputs, because the
+ # last one has exceeded the end.
+ self.assertEqual(4, len(ui.unwrapped_outputs))
+
+ for i in range(4):
+ self.assertEqual(["bar"] * 3, ui.unwrapped_outputs[i].lines)
+ self.assertEqual({}, ui.unwrapped_outputs[i].font_attr_segs)
+
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[0].lines[:3])
+ self.assertEqual({}, ui.wrapped_outputs[0].font_attr_segs)
+
+ for j in range(1, 4):
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[j].lines[:3])
+ self.assertEqual({
+ 0: [(0, 1, "black_on_white"), (2, 3, "black_on_white")],
+ 1: [(0, 1, "black_on_white"), (2, 3, "black_on_white")],
+ 2: [(0, 1, "black_on_white"), (2, 3, "black_on_white")]
+ }, ui.wrapped_outputs[j].font_attr_segs)
+
+ self.assertEqual([0, 0, 1, 2], ui.output_pad_rows)
+
+ def testRegexSearchNoMatchContinuation(self):
+ """Test continuing scrolling when there is no regex match."""
+
+ ui = MockCursesUI(
+ 40,
+ 80,
+ command_sequence=[
+ string_to_codes("babble -n 3\n"),
+ string_to_codes("/foo\n"), # Regex search and highlight.
+ string_to_codes("/\n"), # Continue scrolling down.
+ self._EXIT
+ ])
+
+ ui.register_command_handler(
+ "babble", self._babble, "babble some", prefix_aliases=["b"])
+ ui.run_ui()
+
+ # The regex search and continuation search should not have produced any
+ # output.
+ self.assertEqual(1, len(ui.unwrapped_outputs))
+ self.assertEqual([0], ui.output_pad_rows)
+
+ def testRegexSearchContinuationWithoutSearch(self):
+ """Test continuation scrolling when no regex search has been performed."""
+
+ ui = MockCursesUI(
+ 40,
+ 80,
+ command_sequence=[
+ string_to_codes("babble -n 3\n"),
+ string_to_codes("/\n"), # Continue scrolling without search first.
+ self._EXIT
+ ])
+
+ ui.register_command_handler(
+ "babble", self._babble, "babble some", prefix_aliases=["b"])
+ ui.run_ui()
+
+ self.assertEqual(1, len(ui.unwrapped_outputs))
+ self.assertEqual([0], ui.output_pad_rows)
+
+ def testRegexSearchWithInvalidRegex(self):
+ """Test using invalid regex to search."""
+
+ ui = MockCursesUI(
+ 40,
+ 80,
+ command_sequence=[
+ string_to_codes("babble -n 3\n"),
+ string_to_codes("/[\n"), # Continue scrolling without search first.
+ self._EXIT
+ ])
+
+ ui.register_command_handler(
+ "babble", self._babble, "babble some", prefix_aliases=["b"])
+ ui.run_ui()
+
+ # Invalid regex should not have led to a new screen of output.
+ self.assertEqual(1, len(ui.unwrapped_outputs))
+ self.assertEqual([0], ui.output_pad_rows)
+
+ # Invalid regex should have led to a toast error message.
+ self.assertEqual(["ERROR: Invalid regular expression: \"[\""], ui.toasts)
+
def testRegexSearchFromCommandHistory(self):
"""Test regex search commands are recorded in command history."""
@@ -740,24 +880,214 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(4, len(ui.wrapped_outputs))
- self.assertEqual(["bar"] * 3, ui.wrapped_outputs[0].lines)
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[0].lines[:3])
self.assertEqual({}, ui.wrapped_outputs[0].font_attr_segs)
- self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines)
+ self.assertEqual(["bar"] * 3, ui.wrapped_outputs[1].lines[:3])
for i in range(3):
- self.assertEqual([(0, 1, "bw_reversed"), (2, 3, "bw_reversed")],
+ self.assertEqual([(0, 1, "black_on_white"), (2, 3, "black_on_white")],
ui.wrapped_outputs[1].font_attr_segs[i])
- self.assertEqual(["bar"] * 4, ui.wrapped_outputs[2].lines)
+ self.assertEqual(["bar"] * 4, ui.wrapped_outputs[2].lines[:4])
self.assertEqual({}, ui.wrapped_outputs[2].font_attr_segs)
# The regex search command loaded from history should have worked on the
# new screen output.
- self.assertEqual(["bar"] * 4, ui.wrapped_outputs[3].lines)
+ self.assertEqual(["bar"] * 4, ui.wrapped_outputs[3].lines[:4])
for i in range(4):
- self.assertEqual([(0, 1, "bw_reversed"), (2, 3, "bw_reversed")],
+ self.assertEqual([(0, 1, "black_on_white"), (2, 3, "black_on_white")],
ui.wrapped_outputs[3].font_attr_segs[i])
+ def testDisplayTensorWithIndices(self):
+ """Test displaying tensor with indices."""
+
+ ui = MockCursesUI(
+ 8, # Use a small screen height to cause scrolling.
+ 80,
+ command_sequence=[
+ string_to_codes("print_ones --size 5\n"),
+ [curses.KEY_NPAGE],
+ [curses.KEY_NPAGE],
+ [curses.KEY_NPAGE],
+ [curses.KEY_END],
+ [curses.KEY_NPAGE], # This PageDown goes over the bottom limit.
+ [curses.KEY_PPAGE],
+ [curses.KEY_PPAGE],
+ [curses.KEY_PPAGE],
+ [curses.KEY_HOME],
+ [curses.KEY_PPAGE], # This PageDown goes over the top limit.
+ self._EXIT
+ ])
+
+ ui.register_command_handler("print_ones", self._print_ones,
+ "print an all-one matrix of specified size")
+ ui.run_ui()
+
+ self.assertEqual(11, len(ui.unwrapped_outputs))
+ self.assertEqual(11, len(ui.output_array_pointer_indices))
+ self.assertEqual(11, len(ui.scroll_messages))
+
+ for i in range(11):
+ self.assertEqual([
+ "Tensor \"m\":", "", "array([[ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.]])"
+ ], ui.unwrapped_outputs[i].lines)
+
+ self.assertEqual({
+ 0: None,
+ -1: [1, 0]
+ }, ui.output_array_pointer_indices[0])
+ self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[0])
+
+ # Scrolled down one line.
+ self.assertEqual({
+ 0: None,
+ -1: [2, 0]
+ }, ui.output_array_pointer_indices[1])
+ self.assertIn(" Scroll: 16.67% -[2,0] ", ui.scroll_messages[1])
+
+ # Scrolled down one line.
+ self.assertEqual({
+ 0: [0, 0],
+ -1: [3, 0]
+ }, ui.output_array_pointer_indices[2])
+ self.assertIn(" Scroll: 33.33% [0,0]-[3,0] ", ui.scroll_messages[2])
+
+ # Scrolled down one line.
+ self.assertEqual({
+ 0: [1, 0],
+ -1: [4, 0]
+ }, ui.output_array_pointer_indices[3])
+ self.assertIn(" Scroll: 50.00% [1,0]-[4,0] ", ui.scroll_messages[3])
+
+ # Scroll to the bottom.
+ self.assertEqual({
+ 0: [4, 0],
+ -1: None
+ }, ui.output_array_pointer_indices[4])
+ self.assertIn(" Scroll: 100.00% [4,0]- ", ui.scroll_messages[4])
+
+ # Attempt to scroll beyond the bottom should lead to no change.
+ self.assertEqual({
+ 0: [4, 0],
+ -1: None
+ }, ui.output_array_pointer_indices[5])
+ self.assertIn(" Scroll: 100.00% [4,0]- ", ui.scroll_messages[5])
+
+ # Scrolled up one line.
+ self.assertEqual({
+ 0: [3, 0],
+ -1: None
+ }, ui.output_array_pointer_indices[6])
+ self.assertIn(" Scroll: 83.33% [3,0]- ", ui.scroll_messages[6])
+
+ # Scrolled up one line.
+ self.assertEqual({
+ 0: [2, 0],
+ -1: None
+ }, ui.output_array_pointer_indices[7])
+ self.assertIn(" Scroll: 66.67% [2,0]- ", ui.scroll_messages[7])
+
+ # Scrolled up one line.
+ self.assertEqual({
+ 0: [1, 0],
+ -1: [4, 0]
+ }, ui.output_array_pointer_indices[8])
+ self.assertIn(" Scroll: 50.00% [1,0]-[4,0] ", ui.scroll_messages[8])
+
+ # Scroll to the top.
+ self.assertEqual({
+ 0: None,
+ -1: [1, 0]
+ }, ui.output_array_pointer_indices[9])
+ self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[9])
+
+ # Attempt to scroll pass the top limit should lead to no change.
+ self.assertEqual({
+ 0: None,
+ -1: [1, 0]
+ }, ui.output_array_pointer_indices[10])
+ self.assertIn(" Scroll: 0.00% -[1,0] ", ui.scroll_messages[10])
+
+ def testScrollTensorByValidIndices(self):
+ """Test scrolling to specified (valid) indices in a tensor."""
+
+ ui = MockCursesUI(
+ 8, # Use a small screen height to cause scrolling.
+ 80,
+ command_sequence=[
+ string_to_codes("print_ones --size 5\n"),
+ string_to_codes("@[0, 0]\n"), # Scroll to element [0, 0].
+ string_to_codes("@1,0\n"), # Scroll to element [3, 0].
+ string_to_codes("@[0,2]\n"), # Scroll back to line 0.
+ self._EXIT
+ ])
+
+ ui.register_command_handler("print_ones", self._print_ones,
+ "print an all-one matrix of specified size")
+ ui.run_ui()
+
+ self.assertEqual(4, len(ui.unwrapped_outputs))
+ self.assertEqual(4, len(ui.output_array_pointer_indices))
+
+ for i in range(4):
+ self.assertEqual([
+ "Tensor \"m\":", "", "array([[ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.],",
+ " [ 1., 1., 1., 1., 1.]])"
+ ], ui.unwrapped_outputs[i].lines)
+
+ self.assertEqual({
+ 0: None,
+ -1: [1, 0]
+ }, ui.output_array_pointer_indices[0])
+ self.assertEqual({
+ 0: [0, 0],
+ -1: [3, 0]
+ }, ui.output_array_pointer_indices[1])
+ self.assertEqual({
+ 0: [1, 0],
+ -1: [4, 0]
+ }, ui.output_array_pointer_indices[2])
+ self.assertEqual({
+ 0: [0, 0],
+ -1: [3, 0]
+ }, ui.output_array_pointer_indices[3])
+
+ def testScrollTensorByInvalidIndices(self):
+ """Test scrolling to specified invalid indices in a tensor."""
+
+ ui = MockCursesUI(
+ 8, # Use a small screen height to cause scrolling.
+ 80,
+ command_sequence=[
+ string_to_codes("print_ones --size 5\n"),
+ string_to_codes("@[10, 0]\n"), # Scroll to invalid indices.
+ string_to_codes("@[]\n"), # Scroll to invalid indices.
+ string_to_codes("@\n"), # Scroll to invalid indices.
+ self._EXIT
+ ])
+
+ ui.register_command_handler("print_ones", self._print_ones,
+ "print an all-one matrix of specified size")
+ ui.run_ui()
+
+ # Because all scroll-by-indices commands are invalid, there should be only
+ # one output event.
+ self.assertEqual(1, len(ui.unwrapped_outputs))
+ self.assertEqual(1, len(ui.output_array_pointer_indices))
+
+ # Check error messages.
+ self.assertEqual("ERROR: Indices exceed tensor dimensions.", ui.toasts[1])
+ self.assertEqual("ERROR: invalid literal for int() with base 10: ''",
+ ui.toasts[2])
+ self.assertEqual("ERROR: Empty indices.", ui.toasts[3])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py
index 5f96af52cc..2fd8e6b4cb 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import copy
import re
+import sre_constants
import traceback
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -26,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
HELP_INDENT = " "
EXPLICIT_USER_EXIT = "explicit_user_exit"
+REGEX_MATCH_LINES_KEY = "regex_match_lines"
class CommandLineExit(Exception):
@@ -171,14 +173,21 @@ def regex_find(orig_screen_output, regex, font_attr):
Returns:
A modified copy of orig_screen_output.
+
+ Raises:
+ ValueError: If input str regex is not a valid regular expression.
"""
new_screen_output = RichTextLines(
orig_screen_output.lines,
font_attr_segs=copy.deepcopy(orig_screen_output.font_attr_segs),
annotations=orig_screen_output.annotations)
- re_prog = re.compile(regex)
+ try:
+ re_prog = re.compile(regex)
+ except sre_constants.error:
+ raise ValueError("Invalid regular expression: \"%s\"" % regex)
+ regex_match_lines = []
for i in xrange(len(new_screen_output.lines)):
line = new_screen_output.lines[i]
find_it = re_prog.finditer(line)
@@ -194,7 +203,9 @@ def regex_find(orig_screen_output, regex, font_attr):
new_screen_output.font_attr_segs[i].extend(match_segs)
new_screen_output.font_attr_segs[i] = sorted(
new_screen_output.font_attr_segs[i], key=lambda x: x[0])
+ regex_match_lines.append(i)
+ new_screen_output.annotations[REGEX_MATCH_LINES_KEY] = regex_match_lines
return new_screen_output
@@ -282,6 +293,11 @@ def wrap_rich_text_lines(inp, cols):
out.lines.extend(wlines)
+ # Copy over keys of annotation that are not row indices.
+ for key in inp.annotations:
+ if not isinstance(key, int):
+ out.annotations[key] = inp.annotations[key]
+
return out
diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py
index 2703826d49..36a935eade 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common_test.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py
@@ -407,6 +407,10 @@ class RegexFindTest(test_util.TensorFlowTestCase):
self.assertEqual([(6, 9, "yellow")], new_screen_output.font_attr_segs[0])
self.assertEqual([(8, 11, "yellow")], new_screen_output.font_attr_segs[1])
+ # Check field in annotations carrying a list of matching line indices.
+ self.assertEqual([0, 1], new_screen_output.annotations[
+ debugger_cli_common.REGEX_MATCH_LINES_KEY])
+
def testRegexFindWithExistingFontAttrSegs(self):
# Add a font attribute segment first.
self._orig_screen_output.font_attr_segs[0] = [(9, 12, "red")]
@@ -419,6 +423,21 @@ class RegexFindTest(test_util.TensorFlowTestCase):
self.assertEqual([(6, 9, "yellow"), (9, 12, "red")],
new_screen_output.font_attr_segs[0])
+ self.assertEqual([0, 1], new_screen_output.annotations[
+ debugger_cli_common.REGEX_MATCH_LINES_KEY])
+
+ def testRegexFindWithNoMatches(self):
+ new_screen_output = debugger_cli_common.regex_find(self._orig_screen_output,
+ "infrared", "yellow")
+
+ self.assertEqual({}, new_screen_output.font_attr_segs)
+ self.assertEqual([], new_screen_output.annotations[
+ debugger_cli_common.REGEX_MATCH_LINES_KEY])
+
+ def testInvalidRegex(self):
+ with self.assertRaisesRegexp(ValueError, "Invalid regular expression"):
+ debugger_cli_common.regex_find(self._orig_screen_output, "[", "yellow")
+
class WrapScreenOutputTest(test_util.TensorFlowTestCase):
@@ -445,6 +464,9 @@ class WrapScreenOutputTest(test_util.TensorFlowTestCase):
def testWrappingWithAttrCutoff(self):
out = debugger_cli_common.wrap_rich_text_lines(self._orig_screen_output, 11)
+ # Add non-row-index field to out.
+ out.annotations["metadata"] = "foo"
+
# Check wrapped text.
self.assertEqual(5, len(out.lines))
self.assertEqual("Folk song:", out.lines[0])
@@ -468,6 +490,9 @@ class WrapScreenOutputTest(test_util.TensorFlowTestCase):
self.assertEqual("shorter wavelength", out.annotations[3])
self.assertFalse(4 in out.annotations)
+ # Chec that the non-row-index field is present in output.
+ self.assertEqual("foo", out.annotations["metadata"])
+
def testWrappingWithMultipleAttrCutoff(self):
self._orig_screen_output = debugger_cli_common.RichTextLines(
["Folk song:", "Roses are red", "Violets are blue"],
diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md
index 2e67c129d7..f0aaf6b0fc 100644
--- a/tensorflow/python/debug/examples/README.md
+++ b/tensorflow/python/debug/examples/README.md
@@ -141,9 +141,13 @@ tfdbg>
Try the following commands at the `tfdbg>` prompt:
| Command example | Explanation |
-| ------------- |:--------------------- |
+|:----------------------------- |:----------------------------------- |
| `pt hidden/Relu:0` | Print the value of the tensor `hidden/Relu:0`. |
-| `pt hidden/Relu:0[:, 1]` | Print a subarray of the tensor `hidden/Relu:0`, using numpy-style array slicing. |
+| `pt hidden/Relu:0[0:50,:]` | Print a subarray of the tensor `hidden/Relu:0`, using numpy-style array slicing. |
+| `pt hidden/Relu:0[0:50,:] -a` | For a large tensor like the one here, print its value in its entirety, i.e., without using any ellipsis. May take a long time for large tensors. |
+| `@[10,0]` or `@10,0` | Navigate to indices [10, 0] in the tensor being displayed. |
+| `/inf` | Search the screen output with the regex `inf` and highlight any matches. |
+| `/` | Scroll to the next line with matches to the searched regex (if any). |
| `ni -a hidden/Relu` | Displays information about the node `hidden/Relu`, including node attributes. |
| `li -r hidden/Relu:0` | List the inputs to the node `hidden/Relu`, recursively, i.e., the input tree. |
| `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively, i.e., the output recipient tree. |
@@ -278,16 +282,11 @@ diff = y_ * tf.log(tf.clip_by_value(y, 1e-8, 1.0))
**Other features of the tfdbg diagnostics CLI:**
-<!---
-TODO(cais): Add the following UI features once they are checked in:
-regex search and highlighting.
---->
* Navigation through command history using the Up and Down arrow keys.
Prefix-based navigation is also supported.
* Tab completion of commands and some command arguments.
-
Frequently-asked questions:
===========================
@@ -303,4 +302,4 @@ Frequently-asked questions:
* **Q**: How do I link tfdbg against my Session in Bazel?<br />
**A**: In your BUILD rule, declare the dependency: `"//tensorflow:tensorflow_py"`.
In your Python file, do:
- `from tensorflow.python import debug import tf_debug` \ No newline at end of file
+ `from tensorflow.python import debug as tf_debug`