aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug/cli
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-02-06 10:51:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-06 11:19:58 -0800
commit96cbc386976cd4c631ac47a80692bdcdc6d8df11 (patch)
tree8e521a51fd491472ba9936208befb96e2f36dff2 /tensorflow/python/debug/cli
parent1b6fe5a41cddf09063a4944ecb911ff27f598264 (diff)
tfdbg stepper: display last updated variables
* Add public last_updated() method to NodeStepper class. * The Stepper CLI uses this new method to inform the user of what variables (if any) are updated in the last cont/step action. Change: 146683089
Diffstat (limited to 'tensorflow/python/debug/cli')
-rw-r--r--tensorflow/python/debug/cli/stepper_cli.py94
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py41
2 files changed, 97 insertions, 38 deletions
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index bb76c440bc..6bdbfd72b0 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -65,6 +65,8 @@ class NodeStepperCLI(object):
"Please use full tensor name.",
}
+ _UPDATED_ATTRIBUTE = "bold"
+
_STATE_COLORS = {
STATE_CONT: "green",
STATE_DIRTY_VARIABLE: "magenta",
@@ -74,6 +76,13 @@ class NodeStepperCLI(object):
STATE_UNFEEDABLE: "red",
}
+ _FEED_COLORS = {
+ stepper.NodeStepper.FEED_TYPE_CLIENT: "white",
+ stepper.NodeStepper.FEED_TYPE_HANDLE: "green",
+ stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow",
+ stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue",
+ }
+
def __init__(self, node_stepper):
self._node_stepper = node_stepper
@@ -403,45 +412,13 @@ class NodeStepperCLI(object):
restore_variable_values=parsed.restore_variable_values)
self._completed_nodes.add(parsed.target_name.split(":")[0])
- feed_types = self._node_stepper.last_feed_types()
-
- lines = ["Continued to %s:" % parsed.target_name, ""]
- font_attr_segs = {}
- lines.append("Stepper used feeds:")
- line_counter = len(lines)
-
- if feed_types:
- for feed_name in feed_types:
- feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name])
- lines.append(feed_info_line)
- if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
- font_attr_segs[line_counter] = [
- (len(feed_name) + 2,
- len(feed_info_line),
- self._STATE_COLORS[self.STATE_UNFEEDABLE])
- ]
- elif (feed_types[feed_name] ==
- stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE):
- font_attr_segs[line_counter] = [(
- len(feed_name) + 2,
- len(feed_info_line),
- self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])]
- elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
- font_attr_segs[line_counter] = [
- (len(feed_name) + 2, len(feed_info_line), "yellow")
- ]
- line_counter += 1
- else:
- lines.append(" (No feeds)")
- lines.append("")
-
screen_output = debugger_cli_common.RichTextLines(
- lines, font_attr_segs=font_attr_segs)
-
- tensor_output = tensor_format.format_tensor(
- cont_result, parsed.target_name,
- include_metadata=True)
- screen_output.extend(tensor_output)
+ ["Continued to %s:" % parsed.target_name, ""])
+ screen_output.extend(self._report_last_feed_types())
+ screen_output.extend(self._report_last_updated())
+ screen_output.extend(
+ tensor_format.format_tensor(
+ cont_result, parsed.target_name, include_metadata=True))
# Generate windowed view of the sorted transitive closure on which the
# stepping is occurring.
@@ -458,6 +435,47 @@ class NodeStepperCLI(object):
return final_output
+ def _report_last_feed_types(self):
+ """Generate a report of the feed types used in the cont/step call.
+
+ Returns:
+ (debugger_cli_common.RichTextLines) A RichTextLines representation of the
+ feeds used in the last cont/step call.
+ """
+ feed_types = self._node_stepper.last_feed_types()
+
+ out = debugger_cli_common.RichTextLines(["Stepper used feeds:"])
+ if feed_types:
+ for feed_name in feed_types:
+ feed_info = RL(" %s : " % feed_name)
+ feed_info += RL(feed_types[feed_name],
+ self._FEED_COLORS[feed_types[feed_name]])
+ out.append(feed_info.text, font_attr_segs=feed_info.font_attr_segs)
+ else:
+ out.append(" (No feeds)")
+ out.append("")
+
+ return out
+
+ def _report_last_updated(self):
+ """Generate a report of the variables updated in the last cont/step call.
+
+ Returns:
+ (debugger_cli_common.RichTextLines) A RichTextLines representation of the
+ variables updated in the last cont/step call.
+ """
+
+ last_updated = self._node_stepper.last_updated()
+ if not last_updated:
+ return debugger_cli_common.RichTextLines([])
+
+ rich_lines = [RL("Updated:", self._UPDATED_ATTRIBUTE)]
+ sorted_last_updated = sorted(list(last_updated))
+ for updated in sorted_last_updated:
+ rich_lines.append(RL(" %s" % updated))
+ rich_lines.append(RL(""))
+ return debugger_cli_common.rich_text_lines_from_rich_line_list(rich_lines)
+
def step(self, args, screen_info=None):
"""Step once.
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 0dd4493c95..00a06d1334 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -95,6 +95,38 @@ def _parsed_used_feeds(lines):
feed_types[feed_name] = feed_type
+def _parse_updated(lines):
+ """Parse the Updated section in the output text lines.
+
+ Args:
+ lines: (list of str) The output text lines to be parsed.
+
+ Returns:
+ If the Updated section does not exist, returns None.
+ Otherwise, returns the Tensor names included in the section.
+ """
+ updated = None
+
+ begin_line = -1
+ for i, line in enumerate(lines):
+ if line.startswith("Updated:"):
+ updated = []
+ begin_line = i + 1
+ break
+
+ if begin_line == -1:
+ return updated
+
+ for line in lines[begin_line:]:
+ line = line.strip()
+ if not line:
+ return updated
+ else:
+ updated.append(line.strip())
+
+ return updated
+
+
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -208,6 +240,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertEqual(0, node_pointer)
output = cli.cont("c")
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -218,6 +251,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
output = cli.cont("d")
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -309,6 +343,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
cli = stepper_cli.NodeStepperCLI(node_stepper)
output = cli.cont([no_output_node])
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -333,6 +368,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
cli = stepper_cli.NodeStepperCLI(node_stepper)
output = cli.cont(["opt/update_b/ApplyGradientDescent"])
+ self.assertItemsEqual([self.b.name], _parse_updated(output.lines))
+
output = cli.list_sorted_nodes([])
node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
@@ -346,6 +383,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
output = cli.cont(["opt/update_a/ApplyGradientDescent",
"--invalidate_from_updated_variables"])
+ self.assertItemsEqual([self.a.name], _parse_updated(output.lines))
+
# After cont() call on .../update_a/..., Variable a should have been
# marked as dirty, whereas b should not have.
output = cli.list_sorted_nodes([])
@@ -357,6 +396,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
+ self.assertItemsEqual([self.b.name], _parse_updated(output.lines))
+
# After cont() call on .../update_b/... with the -r flag, Variable b
# should have been marked as dirty, whereas Variable a should not be
# because it should have been restored.