aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-11-08 18:17:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 18:23:20 -0800
commit6eb522b4d6fac69274cbc245a2a0f5e4738ced5a (patch)
tree8a596941d9b423778bff741b7ed32b44bc238216
parent7ea2f7d2689d9686c93650bf5bcc4c8ba459377d (diff)
tfdbg CLI: enable highlighting of tensor elements by range in print_tensor
Command example: $ blaze build -c opt third_party/tensorflow/python/debug:debug_mnist && \ blaze-bin/third_party/tensorflow/python/debug/debug_mnist --debug tfdbg> run tfdbg> pt hidden/Relu:0[0,:] -r [0.25, 0.75] tfdbg> pt hidden/Relu:0[0,:] -r [0.75, inf] tfdbg> pt hidden/Relu:0[0,:] -r "[[0, 0.25], [0.75, inf]]" Change: 138590602
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli.py57
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py44
-rw-r--r--tensorflow/python/debug/cli/command_parser.py44
-rw-r--r--tensorflow/python/debug/cli/command_parser_test.py46
-rw-r--r--tensorflow/python/debug/cli/curses_ui.py2
-rw-r--r--tensorflow/python/debug/cli/debugger_cli_common.py2
-rw-r--r--tensorflow/python/debug/cli/tensor_format.py269
-rw-r--r--tensorflow/python/debug/cli/tensor_format_test.py398
-rw-r--r--tensorflow/python/debug/examples/README.md1
9 files changed, 804 insertions, 59 deletions
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index 868aeaef20..e00f8c810e 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -27,6 +27,7 @@ import argparse
import copy
import re
+import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.debug import debug_data
@@ -201,6 +202,15 @@ class DebugAnalyzer(object):
default=-1,
help="0-based dump number for the specified tensor. "
"Required for tensor with multiple dumps.")
+ ap.add_argument(
+ "-r",
+ "--ranges",
+ dest="ranges",
+ type=str,
+ default="",
+ help="Numerical ranges to highlight tensor elements in. "
+ "Examples: -r 0,1e-8, -r [-0.1,0.1], "
+ "-r \"[[-inf, -0.1], [0.1, inf]]\"")
ap.add_argument(
"-a",
@@ -476,6 +486,9 @@ class DebugAnalyzer(object):
else:
np_printoptions = {}
+ # Determine if any range-highlighting is required.
+ highlight_options = self._parse_ranges_highlight(parsed.ranges)
+
# Determine if there parsed.tensor_name contains any indexing (slicing).
if parsed.tensor_name.count("[") == 1 and parsed.tensor_name.endswith("]"):
tensor_name = parsed.tensor_name[:parsed.tensor_name.index("[")]
@@ -517,7 +530,8 @@ class DebugAnalyzer(object):
matching_data[0].watch_key,
np_printoptions,
print_all=parsed.print_all,
- tensor_slicing=tensor_slicing)
+ tensor_slicing=tensor_slicing,
+ highlight_options=highlight_options)
else:
return self._error(
"Invalid number (%d) for tensor %s, which generated one dump." %
@@ -553,14 +567,45 @@ class DebugAnalyzer(object):
parsed.number,
np_printoptions,
print_all=parsed.print_all,
- tensor_slicing=tensor_slicing)
+ tensor_slicing=tensor_slicing,
+ highlight_options=highlight_options)
+
+ def _parse_ranges_highlight(self, ranges_string):
+ """Process ranges highlight string.
+
+ Args:
+ ranges_string: (str) A string representing a numerical range of a list of
+ numerical ranges. See the help info of the -r flag of the print_tensor
+ command for more details.
+
+ Returns:
+ An instance of tensor_format.HighlightOptions, if range_string is a valid
+ representation of a range or a list of ranges.
+ """
+
+ ranges = None
+
+ def ranges_filter(x):
+ r = np.zeros(x.shape, dtype=bool)
+ for rng_start, rng_end in ranges:
+ r = np.logical_or(r, np.logical_and(x >= rng_start, x <= rng_end))
+
+ return r
+
+ if ranges_string:
+ ranges = command_parser.parse_ranges(ranges_string)
+ return tensor_format.HighlightOptions(
+ ranges_filter, description=ranges_string)
+ else:
+ return None
def _format_tensor(self,
tensor,
watch_key,
np_printoptions,
print_all=False,
- tensor_slicing=None):
+ tensor_slicing=None,
+ highlight_options=None):
"""Generate formatted str to represent a tensor or its slices.
Args:
@@ -575,6 +620,9 @@ class DebugAnalyzer(object):
can handle.)
tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
None, no slicing will be performed on the tensor.
+ highlight_options: (tensor_format.HighlightOptions) options to highlight
+ elements of the tensor. See the doc of tensor_format.format_tensor()
+ for more details.
Returns:
(str) Formatted str representing the (potentially sliced) tensor.
@@ -603,7 +651,8 @@ class DebugAnalyzer(object):
value,
sliced_name,
include_metadata=True,
- np_printoptions=np_printoptions)
+ np_printoptions=np_printoptions,
+ highlight_options=highlight_options)
def list_outputs(self, args, screen_info=None):
"""Command handler for inputs.
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index f555201766..ac1ece6af7 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -505,6 +505,48 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
self.assertIn(4, out.annotations)
self.assertIn(5, out.annotations)
+ def testPrintTensorHighlightingRanges(self):
+ out = self._registry.dispatch_command(
+ "print_tensor", ["simple_mul_add/matmul:0", "--ranges", "[-inf, 0.0]"],
+ screen_info={"cols": 80})
+
+ self.assertEqual([
+ "Tensor \"simple_mul_add/matmul:0:DebugIdentity\": "
+ "Highlighted([-inf, 0.0]): 1 of 2 element(s) (50.00%)",
+ " dtype: float64",
+ " shape: (2, 1)",
+ "",
+ "array([[ 7.],",
+ " [-2.]])",
+ ], out.lines)
+
+ self.assertIn("tensor_metadata", out.annotations)
+ self.assertIn(4, out.annotations)
+ self.assertIn(5, out.annotations)
+ self.assertEqual([(8, 11, "bold")], out.font_attr_segs[5])
+
+ out = self._registry.dispatch_command(
+ "print_tensor",
+ ["simple_mul_add/matmul:0", "--ranges", "[[-inf, -5.5], [5.5, inf]]"],
+ screen_info={"cols": 80})
+
+ self.assertEqual([
+ "Tensor \"simple_mul_add/matmul:0:DebugIdentity\": "
+ "Highlighted([[-inf, -5.5], [5.5, inf]]): "
+ "1 of 2 element(s) (50.00%)",
+ " dtype: float64",
+ " shape: (2, 1)",
+ "",
+ "array([[ 7.],",
+ " [-2.]])",
+ ], out.lines)
+
+ self.assertIn("tensor_metadata", out.annotations)
+ self.assertIn(4, out.annotations)
+ self.assertIn(5, out.annotations)
+ self.assertEqual([(9, 11, "bold")], out.font_attr_segs[4])
+ self.assertNotIn(5, out.font_attr_segs)
+
def testPrintTensorWithSlicing(self):
out = self._registry.dispatch_command(
"print_tensor", ["simple_mul_add/matmul:0[1, :]"],
@@ -667,8 +709,6 @@ class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase):
out = self._registry.dispatch_command(
"print_tensor", ["large_tensors/x:0"], screen_info={"cols": 80})
- print(out.lines) # DEBUG
-
# Assert that ellipses are present in the tensor value printout.
self.assertIn("...,", out.lines[4])
diff --git a/tensorflow/python/debug/cli/command_parser.py b/tensorflow/python/debug/cli/command_parser.py
index 4f940d480c..02ff4b277d 100644
--- a/tensorflow/python/debug/cli/command_parser.py
+++ b/tensorflow/python/debug/cli/command_parser.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import ast
import re
+import sys
+
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
_QUOTES_PATTERN = re.compile(r"\"[^\"]*\"")
@@ -132,3 +135,44 @@ def parse_indices(indices_string):
indices_string = indices_string[1:-1]
return [int(element) for element in indices_string.split(",")]
+
+
+def parse_ranges(range_string):
+ """Parse a string representing numerical range(s).
+
+ Args:
+ range_string: (str) A string representing a numerical range or a list of
+ them. For example:
+ "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]"
+
+ Returns:
+ (list of list of float) A list of numerical ranges parsed from the input
+ string.
+
+ Raises:
+ ValueError: If the input doesn't represent a range or a list of ranges.
+ """
+
+ range_string = range_string.strip()
+ if not range_string:
+ return []
+
+ if "inf" in range_string:
+ range_string = re.sub(r"inf", repr(sys.float_info.max), range_string)
+
+ ranges = ast.literal_eval(range_string)
+ if isinstance(ranges, list) and not isinstance(ranges[0], list):
+ ranges = [ranges]
+
+ # Verify that ranges is a list of list of numbers.
+ for item in ranges:
+ if len(item) != 2:
+ raise ValueError("Incorrect number of elements in range")
+ elif not isinstance(item[0], (int, float)):
+ raise ValueError("Incorrect type in the 1st element of range: %s" %
+ type(item[0]))
+ elif not isinstance(item[1], (int, float)):
+ raise ValueError("Incorrect type in the 2nd element of range: %s" %
+ type(item[0]))
+
+ return ranges
diff --git a/tensorflow/python/debug/cli/command_parser_test.py b/tensorflow/python/debug/cli/command_parser_test.py
index acb1d7f90b..93a3e44e96 100644
--- a/tensorflow/python/debug/cli/command_parser_test.py
+++ b/tensorflow/python/debug/cli/command_parser_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@@ -158,5 +160,49 @@ class ParseIndicesTest(test_util.TensorFlowTestCase):
self.assertEqual([0], command_parser.parse_indices("3, 4,"))
+class ParseRangesTest(test_util.TensorFlowTestCase):
+
+ INF_VALUE = sys.float_info.max
+
+ def testParseEmptyRangeString(self):
+ self.assertEqual([], command_parser.parse_ranges(""))
+ self.assertEqual([], command_parser.parse_ranges(" "))
+
+ def testParseSingleRange(self):
+ self.assertAllClose([[-0.1, 0.2]],
+ command_parser.parse_ranges("[-0.1, 0.2]"))
+ self.assertAllClose([[-0.1, self.INF_VALUE]],
+ command_parser.parse_ranges("[-0.1, inf]"))
+ self.assertAllClose([[-self.INF_VALUE, self.INF_VALUE]],
+ command_parser.parse_ranges("[-inf, inf]"))
+
+ def testParseSingleListOfRanges(self):
+ self.assertAllClose([[-0.1, 0.2], [10.0, 12.0]],
+ command_parser.parse_ranges("[[-0.1, 0.2], [10, 12]]"))
+ self.assertAllClose(
+ [[-self.INF_VALUE, -1.0], [1.0, self.INF_VALUE]],
+ command_parser.parse_ranges("[[-inf, -1.0],[1.0, inf]]"))
+
+ def testParseInvalidRangeString(self):
+ with self.assertRaises(SyntaxError):
+ command_parser.parse_ranges("[[1,2]")
+
+ with self.assertRaisesRegexp(ValueError,
+ "Incorrect number of elements in range"):
+ command_parser.parse_ranges("[1,2,3]")
+
+ with self.assertRaisesRegexp(ValueError,
+ "Incorrect number of elements in range"):
+ command_parser.parse_ranges("[inf]")
+
+ with self.assertRaisesRegexp(ValueError,
+ "Incorrect type in the 1st element of range"):
+ command_parser.parse_ranges("[1j, 1]")
+
+ with self.assertRaisesRegexp(ValueError,
+ "Incorrect type in the 2nd element of range"):
+ command_parser.parse_ranges("[1, 1j]")
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py
index e35e2bf404..8e3d069aa0 100644
--- a/tensorflow/python/debug/cli/curses_ui.py
+++ b/tensorflow/python/debug/cli/curses_ui.py
@@ -448,7 +448,7 @@ class CursesUI(object):
if indices_str:
try:
indices = command_parser.parse_indices(indices_str)
- omitted, line_index = tensor_format.locate_tensor_element(
+ omitted, line_index, _, _ = tensor_format.locate_tensor_element(
self._curr_wrapped_output, indices)
if not omitted:
diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py
index 16187d3b71..d603009401 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common.py
@@ -94,10 +94,12 @@ class RichTextLines(object):
self._font_attr_segs = font_attr_segs
if not self._font_attr_segs:
self._font_attr_segs = {}
+ # TODO(cais): Refactor to collections.defaultdict(list) to simplify code.
self._annotations = annotations
if not self._annotations:
self._annotations = {}
+ # TODO(cais): Refactor to collections.defaultdict(list) to simplify code.
@property
def lines(self):
diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py
index 0f9c5dab1f..9f5ebe47ab 100644
--- a/tensorflow/python/debug/cli/tensor_format.py
+++ b/tensorflow/python/debug/cli/tensor_format.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import re
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -27,12 +28,51 @@ from tensorflow.python.debug.cli import debugger_cli_common
_NUMPY_OMISSION = "...,"
_NUMPY_DEFAULT_EDGE_ITEMS = 3
+_NUMBER_REGEX = re.compile(r"[-+]?([0-9][-+0-9eE\.]+|nan|inf)(\s|,|\])")
+
BEGIN_INDICES_KEY = "i0"
OMITTED_INDICES_KEY = "omitted"
-
-def format_tensor(
- tensor, tensor_name, include_metadata=False, np_printoptions=None):
+DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR = "bold"
+
+
+class HighlightOptions(object):
+ """Options for highlighting elements of a tensor."""
+
+ def __init__(self,
+ criterion,
+ description=None,
+ font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR):
+ """Constructor of HighlightOptions.
+
+ Args:
+ criterion: (callable) A callable of the following signature:
+ def to_highlight(X):
+ # Args:
+ # X: The tensor to highlight elements in.
+ #
+ # Returns:
+ # (boolean ndarray) A boolean ndarray of the same shape as X
+ # indicating which elements are to be highlighted (iff True).
+ This callable will be used as the argument of np.argwhere() to
+ determine which elements of the tensor are to be highlighted.
+ description: (str) Description of the highlight criterion embodied by
+ criterion.
+ font_attr: (str) Font attribute to be applied to the
+ highlighted elements.
+
+ """
+
+ self.criterion = criterion
+ self.description = description
+ self.font_attr = font_attr
+
+
+def format_tensor(tensor,
+ tensor_name,
+ include_metadata=False,
+ np_printoptions=None,
+ highlight_options=None):
"""Generate a RichTextLines object showing a tensor in formatted style.
Args:
@@ -45,6 +85,8 @@ def format_tensor(
np_printoptions: A dictionary of keyword arguments that are passed to a
call of np.set_printoptions() to set the text format for display numpy
ndarrays.
+ highlight_options: (HighlightOptions) options for highlighting elements
+ of the tensor.
Returns:
A RichTextLines object. Its annotation field has line-by-line markups to
@@ -92,7 +134,38 @@ def format_tensor(
annotations = _annotate_ndarray_lines(
array_lines, tensor, np_printoptions=np_printoptions, offset=hlines)
- return debugger_cli_common.RichTextLines(lines, annotations=annotations)
+ formatted = debugger_cli_common.RichTextLines(lines, annotations=annotations)
+
+ # Perform optional highlighting.
+ if highlight_options is not None:
+ indices_list = list(np.argwhere(highlight_options.criterion(tensor)))
+
+ total_elements = np.size(tensor)
+ highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % (
+ "(%s)" % highlight_options.description if highlight_options.description
+ else "", len(indices_list), total_elements,
+ len(indices_list) / float(total_elements) * 100.0)
+
+ formatted.lines[0] += " " + highlight_summary
+
+ if indices_list:
+ indices_list = [list(indices) for indices in indices_list]
+
+ are_omitted, rows, start_cols, end_cols = locate_tensor_element(
+ formatted, indices_list)
+ for is_omitted, row, start_col, end_col in zip(are_omitted, rows,
+ start_cols, end_cols):
+ if is_omitted or start_col is None or end_col is None:
+ continue
+
+ if row in formatted.font_attr_segs:
+ formatted.font_attr_segs[row].append(
+ (start_col, end_col, highlight_options.font_attr))
+ else:
+ formatted.font_attr_segs[row] = [(start_col, end_col,
+ highlight_options.font_attr)]
+
+ return formatted
def _annotate_ndarray_lines(
@@ -181,16 +254,27 @@ def locate_tensor_element(formatted, indices):
Given a RichTextLines object representing a tensor and indices of the sought
element, return the row number at which the element is located (if exists).
- TODO(cais): Return column number as well.
-
Args:
formatted: A RichTextLines object containing formatted text lines
representing the tensor.
- indices: Indices of the sought element, as a list of int.
+ indices: Indices of the sought element, as a list of int or a list of list
+ of int. The former case is for a single set of indices to look up,
+ whereas the latter case is for looking up a batch of indices sets at once.
+ In the latter case, the indices must be in ascending order, or a
+ ValueError will be raised.
Returns:
1) A boolean indicating whether the element falls into an omitted line.
2) Row index.
+ 3) Column start index, i.e., the first column in which the representation
+ of the specified tensor starts, if it can be determined. If it cannot
+ be determined (e.g., due to ellipsis), None.
+ 4) Column end index, i.e., the column right after the last column that
+ represents the specified tensor. Iff it cannot be determined, None.
+
+ For return values described above are based on a single set of indices to
+ look up. In the case of batch mode (multiple sets of indices), the return
+ values will be lists of the types described above.
Raises:
AttributeError: If:
@@ -199,45 +283,168 @@ def locate_tensor_element(formatted, indices):
1) Indices do not match the dimensions of the tensor, or
2) Indices exceed sizes of the tensor, or
3) Indices contain negative value(s).
+ 4) If in batch mode, and if not all sets of indices are in ascending
+ order.
"""
+ if isinstance(indices[0], list):
+ indices_list = indices
+ input_batch = True
+ else:
+ indices_list = [indices]
+ input_batch = False
+
# Check that tensor_metadata is available.
if "tensor_metadata" not in formatted.annotations:
raise AttributeError("tensor_metadata is not available in annotations.")
- # Check "indices" match tensor dimensions.
- dims = formatted.annotations["tensor_metadata"]["shape"]
- if len(indices) != len(dims):
- raise ValueError(
- "Dimensions mismatch: requested: %d; actual: %d" %
- (len(indices), len(dims)))
-
- # Check "indices" is within size limits.
- for req_idx, siz in zip(indices, dims):
- if req_idx >= siz:
- raise ValueError("Indices exceed tensor dimensions.")
- if req_idx < 0:
- raise ValueError("Indices contain negative value(s).")
+ # Sanity check on input argument.
+ _validate_indices_list(indices_list, formatted)
+ dims = formatted.annotations["tensor_metadata"]["shape"]
+ batch_size = len(indices_list)
lines = formatted.lines
annot = formatted.annotations
prev_r = 0
+ prev_line = ""
prev_indices = [0] * len(dims)
+
+ # Initialize return values
+ are_omitted = [None] * batch_size
+ row_indices = [None] * batch_size
+ start_columns = [None] * batch_size
+ end_columns = [None] * batch_size
+
+ batch_pos = 0 # Current position in the batch.
+
for r in xrange(len(lines)):
if r not in annot:
continue
if BEGIN_INDICES_KEY in annot[r]:
- if indices >= prev_indices and indices < annot[r][BEGIN_INDICES_KEY]:
- return OMITTED_INDICES_KEY in annot[prev_r], prev_r
- else:
- prev_r = r
- prev_indices = annot[r][BEGIN_INDICES_KEY]
+ indices_key = BEGIN_INDICES_KEY
elif OMITTED_INDICES_KEY in annot[r]:
- if indices >= prev_indices and indices < annot[r][OMITTED_INDICES_KEY]:
- return OMITTED_INDICES_KEY in annot[prev_r], prev_r
- else:
- prev_r = r
- prev_indices = annot[r][OMITTED_INDICES_KEY]
+ indices_key = OMITTED_INDICES_KEY
+
+ matching_indices_list = [
+ ind for ind in indices_list[batch_pos:]
+ if prev_indices <= ind < annot[r][indices_key]
+ ]
+
+ if matching_indices_list:
+ num_matches = len(matching_indices_list)
+
+ match_start_columns, match_end_columns = _locate_elements_in_line(
+ prev_line, matching_indices_list, prev_indices)
+
+ start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
+ end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
+ are_omitted[batch_pos:batch_pos + num_matches] = [
+ OMITTED_INDICES_KEY in annot[prev_r]
+ ] * num_matches
+ row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches
+
+ batch_pos += num_matches
+ if batch_pos >= batch_size:
+ break
+
+ prev_r = r
+ prev_line = lines[r]
+ prev_indices = annot[r][indices_key]
+
+ if batch_pos < batch_size:
+ matching_indices_list = indices_list[batch_pos:]
+ num_matches = len(matching_indices_list)
+
+ match_start_columns, match_end_columns = _locate_elements_in_line(
+ prev_line, matching_indices_list, prev_indices)
+
+ start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
+ end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
+ are_omitted[batch_pos:batch_pos + num_matches] = [
+ OMITTED_INDICES_KEY in annot[prev_r]
+ ] * num_matches
+ row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches
+
+ if input_batch:
+ return are_omitted, row_indices, start_columns, end_columns
+ else:
+ return are_omitted[0], row_indices[0], start_columns[0], end_columns[0]
+
+
+def _validate_indices_list(indices_list, formatted):
+ prev_ind = None
+ for ind in indices_list:
+ # Check indices match tensor dimensions.
+ dims = formatted.annotations["tensor_metadata"]["shape"]
+ if len(ind) != len(dims):
+ raise ValueError("Dimensions mismatch: requested: %d; actual: %d" %
+ (len(ind), len(dims)))
+
+ # Check indices is within size limits.
+ for req_idx, siz in zip(ind, dims):
+ if req_idx >= siz:
+ raise ValueError("Indices exceed tensor dimensions.")
+ if req_idx < 0:
+ raise ValueError("Indices contain negative value(s).")
+
+ # Check indices are in ascending order.
+ if prev_ind and ind < prev_ind:
+ raise ValueError("Input indices sets are not in ascending order.")
+
+ prev_ind = ind
+
+
+def _locate_elements_in_line(line, indices_list, ref_indices):
+ """Determine the start and end indices of an element in a line.
+
+ Args:
+ line: (str) the line in which the element is to be sought.
+ indices_list: (list of list of int) list of indices of the element to
+ search for. Assumes that the indices in the batch are unique and sorted
+ in ascending order.
+ ref_indices: (list of int) reference indices, i.e., the indices of the
+ first element represented in the line.
+
+ Returns:
+ start_columns: (list of int) start column indices, if found. If not found,
+ None.
+ end_columns: (list of int) end column indices, if found. If not found,
+ None.
+ If found, the element is represented in the left-closed-right-open interval
+ [start_column, end_column].
+ """
+
+ batch_size = len(indices_list)
+ offsets = [indices[-1] - ref_indices[-1] for indices in indices_list]
+
+ start_columns = [None] * batch_size
+ end_columns = [None] * batch_size
+
+ if _NUMPY_OMISSION in line:
+ ellipsis_index = line.find(_NUMPY_OMISSION)
+ else:
+ ellipsis_index = len(line)
+
+ matches_iter = re.finditer(_NUMBER_REGEX, line)
+
+ batch_pos = 0
+
+ offset_counter = 0
+ for match in matches_iter:
+ if match.start() > ellipsis_index:
+ # Do not attempt to search beyond ellipsis.
+ break
+
+ if offset_counter == offsets[batch_pos]:
+ start_columns[batch_pos] = match.start()
+ # Remove the final comma, right bracket, or whitespace.
+ end_columns[batch_pos] = match.end() - 1
+
+ batch_pos += 1
+ if batch_pos >= batch_size:
+ break
+
+ offset_counter += 1
- return OMITTED_INDICES_KEY in annot[prev_r], prev_r
+ return start_columns, end_columns
diff --git a/tensorflow/python/debug/cli/tensor_format_test.py b/tensorflow/python/debug/cli/tensor_format_test.py
index 5d9b150e7c..bd4437887f 100644
--- a/tensorflow/python/debug/cli/tensor_format_test.py
+++ b/tensorflow/python/debug/cli/tensor_format_test.py
@@ -199,6 +199,94 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
self._checkBeginIndices([1, 1, 0], out.annotations[7])
self._checkBeginIndices([1, 2, 0], out.annotations[8])
+ def testFormatTensor3DNoEllipsisWithArgwhereHighlightWithMatches(self):
+ a = np.linspace(0.0, 1.0 - 1.0 / 24.0, 24).reshape([2, 3, 4])
+
+ lower_bound = 0.26
+ upper_bound = 0.5
+
+ def highlight_filter(x):
+ return np.logical_and(x > lower_bound, x < upper_bound)
+
+ highlight_options = tensor_format.HighlightOptions(
+ highlight_filter, description="between 0.26 and 0.5")
+ out = tensor_format.format_tensor(
+ a, "a", highlight_options=highlight_options)
+
+ self.assertEqual([
+ "Tensor \"a\": "
+ "Highlighted(between 0.26 and 0.5): 5 of 24 element(s) (20.83%)",
+ "",
+ "array([[[ 0. , 0.04166667, 0.08333333, 0.125 ],",
+ " [ 0.16666667, 0.20833333, 0.25 , 0.29166667],",
+ " [ 0.33333333, 0.375 , 0.41666667, 0.45833333]],",
+ "",
+ " [[ 0.5 , 0.54166667, 0.58333333, 0.625 ],",
+ " [ 0.66666667, 0.70833333, 0.75 , 0.79166667],",
+ " [ 0.83333333, 0.875 , 0.91666667, 0.95833333]]])",
+ ], out.lines)
+
+ self._checkTensorMetadata(a, out.annotations)
+
+ # Check annotations for beginning indices of the lines.
+ self._checkBeginIndices([0, 0, 0], out.annotations[2])
+ self._checkBeginIndices([0, 1, 0], out.annotations[3])
+ self._checkBeginIndices([0, 2, 0], out.annotations[4])
+ self.assertNotIn(5, out.annotations)
+ self._checkBeginIndices([1, 0, 0], out.annotations[6])
+ self._checkBeginIndices([1, 1, 0], out.annotations[7])
+ self._checkBeginIndices([1, 2, 0], out.annotations[8])
+
+ # Check font attribute segments for highlighted elements.
+ self.assertNotIn(2, out.font_attr_segs)
+ self.assertEqual([(49, 59, "bold")], out.font_attr_segs[3])
+ self.assertEqual([(10, 20, "bold"), (23, 28, "bold"), (36, 46, "bold"),
+ (49, 59, "bold")], out.font_attr_segs[4])
+ self.assertNotIn(5, out.font_attr_segs)
+ self.assertNotIn(6, out.font_attr_segs)
+ self.assertNotIn(7, out.font_attr_segs)
+ self.assertNotIn(8, out.font_attr_segs)
+
+ def testFormatTensor3DNoEllipsisWithArgwhereHighlightWithNoMatches(self):
+ a = np.linspace(0.0, 1.0 - 1.0 / 24.0, 24).reshape([2, 3, 4])
+
+ def highlight_filter(x):
+ return x > 10.0
+
+ highlight_options = tensor_format.HighlightOptions(highlight_filter)
+ out = tensor_format.format_tensor(
+ a, "a", highlight_options=highlight_options)
+
+ self.assertEqual([
+ "Tensor \"a\": Highlighted: 0 of 24 element(s) (0.00%)", "",
+ "array([[[ 0. , 0.04166667, 0.08333333, 0.125 ],",
+ " [ 0.16666667, 0.20833333, 0.25 , 0.29166667],",
+ " [ 0.33333333, 0.375 , 0.41666667, 0.45833333]],", "",
+ " [[ 0.5 , 0.54166667, 0.58333333, 0.625 ],",
+ " [ 0.66666667, 0.70833333, 0.75 , 0.79166667],",
+ " [ 0.83333333, 0.875 , 0.91666667, 0.95833333]]])"
+ ], out.lines)
+
+ self._checkTensorMetadata(a, out.annotations)
+
+ # Check annotations for beginning indices of the lines.
+ self._checkBeginIndices([0, 0, 0], out.annotations[2])
+ self._checkBeginIndices([0, 1, 0], out.annotations[3])
+ self._checkBeginIndices([0, 2, 0], out.annotations[4])
+ self.assertNotIn(5, out.annotations)
+ self._checkBeginIndices([1, 0, 0], out.annotations[6])
+ self._checkBeginIndices([1, 1, 0], out.annotations[7])
+ self._checkBeginIndices([1, 2, 0], out.annotations[8])
+
+ # Check font attribute segments for highlighted elements.
+ self.assertNotIn(2, out.font_attr_segs)
+ self.assertNotIn(3, out.font_attr_segs)
+ self.assertNotIn(4, out.font_attr_segs)
+ self.assertNotIn(5, out.font_attr_segs)
+ self.assertNotIn(6, out.font_attr_segs)
+ self.assertNotIn(7, out.font_attr_segs)
+ self.assertNotIn(8, out.font_attr_segs)
+
def testFormatTensorWithEllipses(self):
a = np.zeros([11, 11, 11])
@@ -277,33 +365,54 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
" 0., 0.])",
], out.lines)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertEqual(8, start_col)
+ self.assertEqual(10, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [5])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [5])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertEqual(33, start_col)
+ self.assertEqual(35, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [6])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [6])
self.assertFalse(is_omitted)
self.assertEqual(3, row)
+ self.assertEqual(8, start_col)
+ self.assertEqual(10, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [11])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [11])
self.assertFalse(is_omitted)
self.assertEqual(3, row)
+ self.assertEqual(33, start_col)
+ self.assertEqual(35, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [12])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [12])
self.assertFalse(is_omitted)
self.assertEqual(4, row)
+ self.assertEqual(8, start_col)
+ self.assertEqual(10, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [18])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [18])
self.assertFalse(is_omitted)
self.assertEqual(5, row)
+ self.assertEqual(8, start_col)
+ self.assertEqual(10, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [19])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [19])
self.assertFalse(is_omitted)
self.assertEqual(5, row)
+ self.assertEqual(13, start_col)
+ self.assertEqual(15, end_col)
with self.assertRaisesRegexp(
ValueError, "Indices exceed tensor dimensions"):
@@ -317,6 +426,144 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
ValueError, "Dimensions mismatch"):
tensor_format.locate_tensor_element(out, [0, 0])
+ def testLocateTensorElement1DNoEllipsisBatchMode(self):
+ a = np.zeros(20)
+
+ out = tensor_format.format_tensor(
+ a, "a", np_printoptions={"linewidth": 40})
+
+ self.assertEqual([
+ "Tensor \"a\":",
+ "",
+ "array([ 0., 0., 0., 0., 0., 0.,",
+ " 0., 0., 0., 0., 0., 0.,",
+ " 0., 0., 0., 0., 0., 0.,",
+ " 0., 0.])",
+ ], out.lines)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0]])
+ self.assertEqual([False], are_omitted)
+ self.assertEqual([2], rows)
+ self.assertEqual([8], start_cols)
+ self.assertEqual([10], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0], [5]])
+ self.assertEqual([False, False], are_omitted)
+ self.assertEqual([2, 2], rows)
+ self.assertEqual([8, 33], start_cols)
+ self.assertEqual([10, 35], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0], [6]])
+ self.assertEqual([False, False], are_omitted)
+ self.assertEqual([2, 3], rows)
+ self.assertEqual([8, 8], start_cols)
+ self.assertEqual([10, 10], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0], [5], [6]])
+ self.assertEqual([False, False, False], are_omitted)
+ self.assertEqual([2, 2, 3], rows)
+ self.assertEqual([8, 33, 8], start_cols)
+ self.assertEqual([10, 35, 10], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0], [5], [6], [19]])
+ self.assertEqual([False, False, False, False], are_omitted)
+ self.assertEqual([2, 2, 3, 5], rows)
+ self.assertEqual([8, 33, 8, 13], start_cols)
+ self.assertEqual([10, 35, 10, 15], end_cols)
+
+ def testBatchModeWithErrors(self):
+ a = np.zeros(20)
+
+ out = tensor_format.format_tensor(
+ a, "a", np_printoptions={"linewidth": 40})
+
+ self.assertEqual([
+ "Tensor \"a\":",
+ "",
+ "array([ 0., 0., 0., 0., 0., 0.,",
+ " 0., 0., 0., 0., 0., 0.,",
+ " 0., 0., 0., 0., 0., 0.,",
+ " 0., 0.])",
+ ], out.lines)
+
+ with self.assertRaisesRegexp(ValueError, "Dimensions mismatch"):
+ tensor_format.locate_tensor_element(out, [[0, 0], [0]])
+
+ with self.assertRaisesRegexp(ValueError,
+ "Indices exceed tensor dimensions"):
+ tensor_format.locate_tensor_element(out, [[0], [20]])
+
+ with self.assertRaisesRegexp(ValueError,
+ r"Indices contain negative value\(s\)"):
+ tensor_format.locate_tensor_element(out, [[0], [-1]])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Input indices sets are not in ascending order"):
+ tensor_format.locate_tensor_element(out, [[5], [0]])
+
+ def testLocateTensorElement1DTinyAndNanValues(self):
+ a = np.ones([3, 3]) * 1e-8
+ a[1, 0] = np.nan
+ a[1, 2] = np.inf
+
+ out = tensor_format.format_tensor(
+ a, "a", np_printoptions={"linewidth": 100})
+
+ self.assertEqual([
+ "Tensor \"a\":",
+ "",
+ "array([[ 1.00000000e-08, 1.00000000e-08, 1.00000000e-08],",
+ " [ nan, 1.00000000e-08, inf],",
+ " [ 1.00000000e-08, 1.00000000e-08, 1.00000000e-08]])",
+ ], out.lines)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 0])
+ self.assertFalse(is_omitted)
+ self.assertEqual(2, row)
+ self.assertEqual(10, start_col)
+ self.assertEqual(24, end_col)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 2])
+ self.assertFalse(is_omitted)
+ self.assertEqual(2, row)
+ self.assertEqual(46, start_col)
+ self.assertEqual(60, end_col)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [1, 0])
+ self.assertFalse(is_omitted)
+ self.assertEqual(3, row)
+ self.assertEqual(21, start_col)
+ self.assertEqual(24, end_col)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [1, 1])
+ self.assertFalse(is_omitted)
+ self.assertEqual(3, row)
+ self.assertEqual(28, start_col)
+ self.assertEqual(42, end_col)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [1, 2])
+ self.assertFalse(is_omitted)
+ self.assertEqual(3, row)
+ self.assertEqual(57, start_col)
+ self.assertEqual(60, end_col)
+
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [2, 2])
+ self.assertFalse(is_omitted)
+ self.assertEqual(4, row)
+ self.assertEqual(46, start_col)
+ self.assertEqual(60, end_col)
+
def testLocateTensorElement2DNoEllipsis(self):
a = np.linspace(0.0, 1.0 - 1.0 / 16.0, 16).reshape([4, 4])
@@ -331,25 +578,40 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
" [ 0.75 , 0.8125, 0.875 , 0.9375]])",
], out.lines)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 0])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertEqual(9, start_col)
+ self.assertEqual(11, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 3])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 3])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertEqual(36, start_col)
+ self.assertEqual(42, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [1, 0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [1, 0])
self.assertFalse(is_omitted)
self.assertEqual(3, row)
+ self.assertEqual(9, start_col)
+ self.assertEqual(13, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [1, 3])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [1, 3])
self.assertFalse(is_omitted)
self.assertEqual(3, row)
+ self.assertEqual(36, start_col)
+ self.assertEqual(42, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [3, 3])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [3, 3])
self.assertFalse(is_omitted)
self.assertEqual(5, row)
+ self.assertEqual(36, start_col)
+ self.assertEqual(42, end_col)
with self.assertRaisesRegexp(
ValueError, "Indices exceed tensor dimensions"):
@@ -398,41 +660,68 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
" [ 0., 0., ..., 0., 0.]]])",
], out.lines)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 0, 0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 0, 0])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertEqual(10, start_col)
+ self.assertEqual(12, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 0, 10])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 0, 10])
self.assertFalse(is_omitted)
self.assertEqual(2, row)
+ self.assertIsNone(start_col) # Passes ellipsis.
+ self.assertIsNone(end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 1, 0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 1, 0])
self.assertFalse(is_omitted)
self.assertEqual(3, row)
+ self.assertEqual(10, start_col)
+ self.assertEqual(12, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 2, 0])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 2, 0])
self.assertTrue(is_omitted) # In omitted line.
self.assertEqual(4, row)
+ self.assertIsNone(start_col)
+ self.assertIsNone(end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 2, 10])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 2, 10])
self.assertTrue(is_omitted) # In omitted line.
self.assertEqual(4, row)
+ self.assertIsNone(start_col)
+ self.assertIsNone(end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 8, 10])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 8, 10])
self.assertTrue(is_omitted) # In omitted line.
self.assertEqual(4, row)
+ self.assertIsNone(start_col)
+ self.assertIsNone(end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [0, 10, 1])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [0, 10, 1])
self.assertFalse(is_omitted)
self.assertEqual(6, row)
+ self.assertEqual(15, start_col)
+ self.assertEqual(17, end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [5, 1, 1])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [5, 1, 1])
self.assertTrue(is_omitted) # In omitted line.
self.assertEqual(14, row)
+ self.assertIsNone(start_col)
+ self.assertIsNone(end_col)
- is_omitted, row = tensor_format.locate_tensor_element(out, [10, 10, 10])
+ is_omitted, row, start_col, end_col = tensor_format.locate_tensor_element(
+ out, [10, 10, 10])
self.assertFalse(is_omitted)
self.assertEqual(25, row)
+ self.assertIsNone(start_col) # Past ellipsis.
+ self.assertIsNone(end_col)
with self.assertRaisesRegexp(
ValueError, "Indices exceed tensor dimensions"):
@@ -446,6 +735,73 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
ValueError, "Dimensions mismatch"):
tensor_format.locate_tensor_element(out, [5, 5])
+ def testLocateTensorElement3DWithEllipsesBatchMode(self):
+ a = np.zeros([11, 11, 11])
+
+ out = tensor_format.format_tensor(
+ a, "a", False, np_printoptions={"threshold": 100,
+ "edgeitems": 2})
+
+ self.assertEqual([
+ "Tensor \"a\":",
+ "",
+ "array([[[ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.],",
+ " ..., ",
+ " [ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.]],",
+ "",
+ " [[ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.],",
+ " ..., ",
+ " [ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.]],",
+ "",
+ " ..., ",
+ " [[ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.],",
+ " ..., ",
+ " [ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.]],",
+ "",
+ " [[ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.],",
+ " ..., ",
+ " [ 0., 0., ..., 0., 0.],",
+ " [ 0., 0., ..., 0., 0.]]])",
+ ], out.lines)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out, [[0, 0, 0]])
+ self.assertEqual([False], are_omitted)
+ self.assertEqual([2], rows)
+ self.assertEqual([10], start_cols)
+ self.assertEqual([12], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out,
+ [[0, 0, 0], [0, 0, 10]])
+ self.assertEqual([False, False], are_omitted)
+ self.assertEqual([2, 2], rows)
+ self.assertEqual([10, None], start_cols)
+ self.assertEqual([12, None], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out,
+ [[0, 0, 0], [0, 2, 0]])
+ self.assertEqual([False, True], are_omitted)
+ self.assertEqual([2, 4], rows)
+ self.assertEqual([10, None], start_cols)
+ self.assertEqual([12, None], end_cols)
+
+ (are_omitted, rows, start_cols,
+ end_cols) = tensor_format.locate_tensor_element(out,
+ [[0, 0, 0], [10, 10, 10]])
+ self.assertEqual([False, False], are_omitted)
+ self.assertEqual([2, 25], rows)
+ self.assertEqual([10, None], start_cols)
+ self.assertEqual([12, None], end_cols)
+
def testLocateTensorElementAnnotationsUnavailable(self):
out = tensor_format.format_tensor(None, "a")
diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md
index a5c56e8c54..8719dea3f7 100644
--- a/tensorflow/python/debug/examples/README.md
+++ b/tensorflow/python/debug/examples/README.md
@@ -156,6 +156,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at
| `pt hidden/Relu:0` | Print the value of the tensor `hidden/Relu:0`. |
| `pt hidden/Relu:0[0:50,:]` | Print a subarray of the tensor `hidden/Relu:0`, using [numpy](http://www.numpy.org/)-style array slicing. |
| `pt hidden/Relu:0[0:50,:] -a` | For a large tensor like the one here, print its value in its entirety—i.e., without using any ellipsis. May take a long time for large tensors. |
+| `pt hidden/Relu:0[0:10,:] -a -r [1,inf]` | Use the `-r` flag to highlight elements falling into the specified numerical range. Multiple ranges can be used in conjunction, e.g., `-r [[-inf,-1],[1,inf]]`.|
| `@[10,0]` or `@10,0` | Navigate to indices [10, 0] in the tensor being displayed. |
| `/inf` | Search the screen output with the regex `inf` and highlight any matches. |
| `/` | Scroll to the next line with matches to the searched regex (if any). |