diff options
author | 2017-02-06 10:51:16 -0800 | |
---|---|---|
committer | 2017-02-06 11:19:58 -0800 | |
commit | 96cbc386976cd4c631ac47a80692bdcdc6d8df11 (patch) | |
tree | 8e521a51fd491472ba9936208befb96e2f36dff2 /tensorflow/python/debug/cli | |
parent | 1b6fe5a41cddf09063a4944ecb911ff27f598264 (diff) |
tfdbg stepper: display last updated variables
* Add public last_updated() method to NodeStepper class.
* The Stepper CLI uses this new method to inform the user of what variables (if any) are updated in the last cont/step action.
Change: 146683089
Diffstat (limited to 'tensorflow/python/debug/cli')
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli.py | 94 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli_test.py | 41 |
2 files changed, 97 insertions, 38 deletions
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py index bb76c440bc..6bdbfd72b0 100644 --- a/tensorflow/python/debug/cli/stepper_cli.py +++ b/tensorflow/python/debug/cli/stepper_cli.py @@ -65,6 +65,8 @@ class NodeStepperCLI(object): "Please use full tensor name.", } + _UPDATED_ATTRIBUTE = "bold" + _STATE_COLORS = { STATE_CONT: "green", STATE_DIRTY_VARIABLE: "magenta", @@ -74,6 +76,13 @@ class NodeStepperCLI(object): STATE_UNFEEDABLE: "red", } + _FEED_COLORS = { + stepper.NodeStepper.FEED_TYPE_CLIENT: "white", + stepper.NodeStepper.FEED_TYPE_HANDLE: "green", + stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow", + stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue", + } + def __init__(self, node_stepper): self._node_stepper = node_stepper @@ -403,45 +412,13 @@ class NodeStepperCLI(object): restore_variable_values=parsed.restore_variable_values) self._completed_nodes.add(parsed.target_name.split(":")[0]) - feed_types = self._node_stepper.last_feed_types() - - lines = ["Continued to %s:" % parsed.target_name, ""] - font_attr_segs = {} - lines.append("Stepper used feeds:") - line_counter = len(lines) - - if feed_types: - for feed_name in feed_types: - feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name]) - lines.append(feed_info_line) - if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE: - font_attr_segs[line_counter] = [ - (len(feed_name) + 2, - len(feed_info_line), - self._STATE_COLORS[self.STATE_UNFEEDABLE]) - ] - elif (feed_types[feed_name] == - stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE): - font_attr_segs[line_counter] = [( - len(feed_name) + 2, - len(feed_info_line), - self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])] - elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE: - font_attr_segs[line_counter] = [ - (len(feed_name) + 2, len(feed_info_line), "yellow") - ] - line_counter += 1 - else: - lines.append(" (No feeds)") - lines.append("") - screen_output = debugger_cli_common.RichTextLines( - lines, font_attr_segs=font_attr_segs) - - tensor_output = tensor_format.format_tensor( - cont_result, parsed.target_name, - include_metadata=True) - screen_output.extend(tensor_output) + ["Continued to %s:" % parsed.target_name, ""]) + screen_output.extend(self._report_last_feed_types()) + screen_output.extend(self._report_last_updated()) + screen_output.extend( + tensor_format.format_tensor( + cont_result, parsed.target_name, include_metadata=True)) # Generate windowed view of the sorted transitive closure on which the # stepping is occurring. @@ -458,6 +435,47 @@ class NodeStepperCLI(object): return final_output + def _report_last_feed_types(self): + """Generate a report of the feed types used in the cont/step call. + + Returns: + (debugger_cli_common.RichTextLines) A RichTextLines representation of the + feeds used in the last cont/step call. + """ + feed_types = self._node_stepper.last_feed_types() + + out = debugger_cli_common.RichTextLines(["Stepper used feeds:"]) + if feed_types: + for feed_name in feed_types: + feed_info = RL(" %s : " % feed_name) + feed_info += RL(feed_types[feed_name], + self._FEED_COLORS[feed_types[feed_name]]) + out.append(feed_info.text, font_attr_segs=feed_info.font_attr_segs) + else: + out.append(" (No feeds)") + out.append("") + + return out + + def _report_last_updated(self): + """Generate a report of the variables updated in the last cont/step call. + + Returns: + (debugger_cli_common.RichTextLines) A RichTextLines representation of the + variables updated in the last cont/step call. + """ + + last_updated = self._node_stepper.last_updated() + if not last_updated: + return debugger_cli_common.RichTextLines([]) + + rich_lines = [RL("Updated:", self._UPDATED_ATTRIBUTE)] + sorted_last_updated = sorted(list(last_updated)) + for updated in sorted_last_updated: + rich_lines.append(RL(" %s" % updated)) + rich_lines.append(RL("")) + return debugger_cli_common.rich_text_lines_from_rich_line_list(rich_lines) + def step(self, args, screen_info=None): """Step once. diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py index 0dd4493c95..00a06d1334 100644 --- a/tensorflow/python/debug/cli/stepper_cli_test.py +++ b/tensorflow/python/debug/cli/stepper_cli_test.py @@ -95,6 +95,38 @@ def _parsed_used_feeds(lines): feed_types[feed_name] = feed_type +def _parse_updated(lines): + """Parse the Updated section in the output text lines. + + Args: + lines: (list of str) The output text lines to be parsed. + + Returns: + If the Updated section does not exist, returns None. + Otherwise, returns the Tensor names included in the section. + """ + updated = None + + begin_line = -1 + for i, line in enumerate(lines): + if line.startswith("Updated:"): + updated = [] + begin_line = i + 1 + break + + if begin_line == -1: + return updated + + for line in lines[begin_line:]: + line = line.strip() + if not line: + return updated + else: + updated.append(line.strip()) + + return updated + + class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -208,6 +240,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): self.assertEqual(0, node_pointer) output = cli.cont("c") + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -218,6 +251,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c]) output = cli.cont("d") + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -309,6 +343,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): cli = stepper_cli.NodeStepperCLI(node_stepper) output = cli.cont([no_output_node]) + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -333,6 +368,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): cli = stepper_cli.NodeStepperCLI(node_stepper) output = cli.cont(["opt/update_b/ApplyGradientDescent"]) + self.assertItemsEqual([self.b.name], _parse_updated(output.lines)) + output = cli.list_sorted_nodes([]) node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, @@ -346,6 +383,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): output = cli.cont(["opt/update_a/ApplyGradientDescent", "--invalidate_from_updated_variables"]) + self.assertItemsEqual([self.a.name], _parse_updated(output.lines)) + # After cont() call on .../update_a/..., Variable a should have been # marked as dirty, whereas b should not have. output = cli.list_sorted_nodes([]) @@ -357,6 +396,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"]) + self.assertItemsEqual([self.b.name], _parse_updated(output.lines)) + # After cont() call on .../update_b/... with the -r flag, Variable b # should have been marked as dirty, whereas Variable a should not be # because it should have been restored. |