diff options
author | Shanqing Cai <cais@google.com> | 2017-02-06 10:51:16 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-06 11:19:58 -0800 |
commit | 96cbc386976cd4c631ac47a80692bdcdc6d8df11 (patch) | |
tree | 8e521a51fd491472ba9936208befb96e2f36dff2 | |
parent | 1b6fe5a41cddf09063a4944ecb911ff27f598264 (diff) |
tfdbg stepper: display last updated variables
* Add public last_updated() method to NodeStepper class.
* The Stepper CLI uses this new method to inform the user of what variables (if any) are updated in the last cont/step action.
Change: 146683089
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli.py | 94 | ||||
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli_test.py | 41 | ||||
-rw-r--r-- | tensorflow/python/debug/stepper.py | 30 | ||||
-rw-r--r-- | tensorflow/python/debug/stepper_test.py | 21 |
4 files changed, 137 insertions, 49 deletions
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py index bb76c440bc..6bdbfd72b0 100644 --- a/tensorflow/python/debug/cli/stepper_cli.py +++ b/tensorflow/python/debug/cli/stepper_cli.py @@ -65,6 +65,8 @@ class NodeStepperCLI(object): "Please use full tensor name.", } + _UPDATED_ATTRIBUTE = "bold" + _STATE_COLORS = { STATE_CONT: "green", STATE_DIRTY_VARIABLE: "magenta", @@ -74,6 +76,13 @@ class NodeStepperCLI(object): STATE_UNFEEDABLE: "red", } + _FEED_COLORS = { + stepper.NodeStepper.FEED_TYPE_CLIENT: "white", + stepper.NodeStepper.FEED_TYPE_HANDLE: "green", + stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow", + stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue", + } + def __init__(self, node_stepper): self._node_stepper = node_stepper @@ -403,45 +412,13 @@ class NodeStepperCLI(object): restore_variable_values=parsed.restore_variable_values) self._completed_nodes.add(parsed.target_name.split(":")[0]) - feed_types = self._node_stepper.last_feed_types() - - lines = ["Continued to %s:" % parsed.target_name, ""] - font_attr_segs = {} - lines.append("Stepper used feeds:") - line_counter = len(lines) - - if feed_types: - for feed_name in feed_types: - feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name]) - lines.append(feed_info_line) - if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE: - font_attr_segs[line_counter] = [ - (len(feed_name) + 2, - len(feed_info_line), - self._STATE_COLORS[self.STATE_UNFEEDABLE]) - ] - elif (feed_types[feed_name] == - stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE): - font_attr_segs[line_counter] = [( - len(feed_name) + 2, - len(feed_info_line), - self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])] - elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE: - font_attr_segs[line_counter] = [ - (len(feed_name) + 2, len(feed_info_line), "yellow") - ] - line_counter += 1 - else: - lines.append(" (No feeds)") - lines.append("") - screen_output = debugger_cli_common.RichTextLines( - lines, font_attr_segs=font_attr_segs) - - tensor_output = tensor_format.format_tensor( - cont_result, parsed.target_name, - include_metadata=True) - screen_output.extend(tensor_output) + ["Continued to %s:" % parsed.target_name, ""]) + screen_output.extend(self._report_last_feed_types()) + screen_output.extend(self._report_last_updated()) + screen_output.extend( + tensor_format.format_tensor( + cont_result, parsed.target_name, include_metadata=True)) # Generate windowed view of the sorted transitive closure on which the # stepping is occurring. @@ -458,6 +435,47 @@ class NodeStepperCLI(object): return final_output + def _report_last_feed_types(self): + """Generate a report of the feed types used in the cont/step call. + + Returns: + (debugger_cli_common.RichTextLines) A RichTextLines representation of the + feeds used in the last cont/step call. + """ + feed_types = self._node_stepper.last_feed_types() + + out = debugger_cli_common.RichTextLines(["Stepper used feeds:"]) + if feed_types: + for feed_name in feed_types: + feed_info = RL(" %s : " % feed_name) + feed_info += RL(feed_types[feed_name], + self._FEED_COLORS[feed_types[feed_name]]) + out.append(feed_info.text, font_attr_segs=feed_info.font_attr_segs) + else: + out.append(" (No feeds)") + out.append("") + + return out + + def _report_last_updated(self): + """Generate a report of the variables updated in the last cont/step call. + + Returns: + (debugger_cli_common.RichTextLines) A RichTextLines representation of the + variables updated in the last cont/step call. + """ + + last_updated = self._node_stepper.last_updated() + if not last_updated: + return debugger_cli_common.RichTextLines([]) + + rich_lines = [RL("Updated:", self._UPDATED_ATTRIBUTE)] + sorted_last_updated = sorted(list(last_updated)) + for updated in sorted_last_updated: + rich_lines.append(RL(" %s" % updated)) + rich_lines.append(RL("")) + return debugger_cli_common.rich_text_lines_from_rich_line_list(rich_lines) + def step(self, args, screen_info=None): """Step once. diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py index 0dd4493c95..00a06d1334 100644 --- a/tensorflow/python/debug/cli/stepper_cli_test.py +++ b/tensorflow/python/debug/cli/stepper_cli_test.py @@ -95,6 +95,38 @@ def _parsed_used_feeds(lines): feed_types[feed_name] = feed_type +def _parse_updated(lines): + """Parse the Updated section in the output text lines. + + Args: + lines: (list of str) The output text lines to be parsed. + + Returns: + If the Updated section does not exist, returns None. + Otherwise, returns the Tensor names included in the section. + """ + updated = None + + begin_line = -1 + for i, line in enumerate(lines): + if line.startswith("Updated:"): + updated = [] + begin_line = i + 1 + break + + if begin_line == -1: + return updated + + for line in lines[begin_line:]: + line = line.strip() + if not line: + return updated + else: + updated.append(line.strip()) + + return updated + + class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -208,6 +240,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): self.assertEqual(0, node_pointer) output = cli.cont("c") + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -218,6 +251,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c]) output = cli.cont("d") + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -309,6 +343,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): cli = stepper_cli.NodeStepperCLI(node_stepper) output = cli.cont([no_output_node]) + self.assertIsNone(_parse_updated(output.lines)) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) @@ -333,6 +368,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): cli = stepper_cli.NodeStepperCLI(node_stepper) output = cli.cont(["opt/update_b/ApplyGradientDescent"]) + self.assertItemsEqual([self.b.name], _parse_updated(output.lines)) + output = cli.list_sorted_nodes([]) node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, @@ -346,6 +383,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): output = cli.cont(["opt/update_a/ApplyGradientDescent", "--invalidate_from_updated_variables"]) + self.assertItemsEqual([self.a.name], _parse_updated(output.lines)) + # After cont() call on .../update_a/..., Variable a should have been # marked as dirty, whereas b should not have. output = cli.list_sorted_nodes([]) @@ -357,6 +396,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"]) + self.assertItemsEqual([self.b.name], _parse_updated(output.lines)) + # After cont() call on .../update_b/... with the -r flag, Variable b # should have been marked as dirty, whereas Variable a should not be # because it should have been restored. diff --git a/tensorflow/python/debug/stepper.py b/tensorflow/python/debug/stepper.py index 3cbe83d072..ab500f52e3 100644 --- a/tensorflow/python/debug/stepper.py +++ b/tensorflow/python/debug/stepper.py @@ -156,6 +156,9 @@ class NodeStepper(object): # Keep track of which variables are in a dirty state. self._dirty_variables = set() + # Variables updated in the last cont() call. + self._last_updated = None + # Cached tensor handles: a dict with keys as tensor names and values as # tensor handles. self._tensor_handles = {} @@ -441,9 +444,6 @@ class NodeStepper(object): return self._last_feed_types - # TODO(cais): Add method last_updated_variables() to allow client to look - # up the Variables updated in the last cont() call. - def cont(self, target, use_tensor_handles=True, @@ -550,7 +550,7 @@ class NodeStepper(object): # Keep track of which variables are "touched" (i.e., possibly updated) in # this cont() call. - touched_variables = set() + self._last_updated = set() # ========================================================================= # Use a non-recursive method to trace the inputs from the node and set up @@ -583,7 +583,7 @@ class NodeStepper(object): if (restore_variable_values and inp.name in self._dirty_variables and inp.name not in restored_variables and - inp.name not in touched_variables): + inp.name not in self._last_updated): # Do not restore Variables touched or restored previously in this # cont() call. initializer_op = self._variable_initializers[inp.name] @@ -604,7 +604,7 @@ class NodeStepper(object): if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"] and curr_node.type != "Identity"): # Mark the variable as dirty. - touched_variables.add(inp.name) + self._last_updated.add(inp.name) # Obtain the old value of the variable and cache it. if inp.name not in self._cached_variable_values: @@ -648,8 +648,8 @@ class NodeStepper(object): # ========================================================================= - if touched_variables: - self._dirty_variables.update(touched_variables) + if self._last_updated: + self._dirty_variables.update(self._last_updated) for variable in restored_variables: self._dirty_variables.remove(variable) @@ -686,8 +686,8 @@ class NodeStepper(object): if invalidate_from_updated_variables: # Invalidate caches at the end. - for touched_variable in touched_variables: - self._invalidate_transitively_outgoing_cache(touched_variable) + for last_updated_variable in self._last_updated: + self._invalidate_transitively_outgoing_cache(last_updated_variable) return return_value @@ -848,6 +848,16 @@ class NodeStepper(object): return self._dumped_intermediate_tensors.keys() + def last_updated(self): + """Get the names of the variables updated in the last cont() call. + + Returns: + A set of the variable names updated in the previous cont() call. + If no cont() call has occurred before, returns None. + """ + + return self._last_updated + def dirty_variables(self): """Get the set of variables that are currently "dirty". diff --git a/tensorflow/python/debug/stepper_test.py b/tensorflow/python/debug/stepper_test.py index 63501b4fe6..41a62c49b1 100644 --- a/tensorflow/python/debug/stepper_test.py +++ b/tensorflow/python/debug/stepper_test.py @@ -587,6 +587,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): def tearDown(self): ops.reset_default_graph() + def testLastUpdatedVariablesReturnsNoneBeforeAnyContCalls(self): + with NodeStepper(self.sess, [self.q, self.v_add]) as stepper: + self.assertIsNone(stepper.last_updated()) + def testContToUpdateInvalidatesDumpedIntermedates(self): with NodeStepper(self.sess, [self.q, self.v_add]) as stepper: self.assertAllClose(400.0, stepper.cont("q:0")) @@ -599,6 +603,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): 12.0, stepper.cont( self.v_add, invalidate_from_updated_variables=True)) self.assertAllClose(12.0, self.sess.run(self.v)) + self.assertSetEqual({self.v.name}, stepper.last_updated()) self.assertItemsEqual(["v:0"], stepper.dirty_variables()) # Updating the value of v by calling v_add should have invalidated the # dumped intermediate tensors for v/read:0 and p:0. @@ -612,6 +617,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): # The next cont to q should not have used any dumped intermediate tensors # and its result should reflect the updated value. self.assertAllClose(576.0, stepper.cont("q:0")) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual({}, stepper.last_feed_types()) def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self): @@ -652,8 +658,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self): with NodeStepper(self.sess, self.v_add) as stepper: stepper.cont(self.v_add) + self.assertSetEqual({self.v.name}, stepper.last_updated()) self.assertAllClose(12.0, stepper.cont(self.v)) stepper.cont(self.v_add) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE}, stepper.last_feed_types()) self.assertAllClose(12.0, stepper.cont(self.v)) @@ -661,8 +669,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase): def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self): with NodeStepper(self.sess, self.v_add_plus_one) as stepper: stepper.cont(self.v_add_plus_one) + self.assertSetEqual({self.v.name}, stepper.last_updated()) self.assertAllClose(12.0, stepper.cont(self.v)) stepper.cont(self.v_add_plus_one) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE}, stepper.last_feed_types()) self.assertAllClose(12.0, stepper.cont(self.v)) @@ -724,6 +734,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): # Now variable a should have been marked as dirty due to the update # by optim/update_a/ApplyGradientDescent. + self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) self.assertIsNone(result) self.assertEqual({ @@ -745,6 +756,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): result = stepper.cont("optim/update_b/ApplyGradientDescent", invalidate_from_updated_variables=True) self.assertIsNone(result) + self.assertSetEqual({"b:0"}, stepper.last_updated()) self.assertEqual(set(["b:0"]), stepper.dirty_variables()) # For backprop on Variable b: @@ -763,6 +775,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): invalidate_from_updated_variables=True, restore_variable_values=True) self.assertIsNone(result) + self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual(set(["a:0"]), stepper.dirty_variables()) self.assertAllClose(0.84, self.sess.run(self.a)) self.assertAllClose(2.0, self.sess.run(self.b)) @@ -790,6 +803,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): invalidate_from_updated_variables=False)) # Even though invalidate_from_updated_variables is set to False, dirty # variables should still have been tracked. + self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) self.assertIn("a/read:0", stepper.intermediate_tensor_names()) self.assertIn("b/read:0", stepper.intermediate_tensor_names()) @@ -825,6 +839,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): invalidate_from_updated_variables=True, restore_variable_values=True) self.assertIsNone(result) + self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) result = stepper.cont( @@ -834,6 +849,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): self.assertIsNone(result) # Variables a and c should have been restored and hence no longer dirty. # Variable b should have been marked as dirty. + self.assertSetEqual({"b:0"}, stepper.last_updated()) self.assertEqual({"b:0"}, stepper.dirty_variables()) # The result of the update should be identitcal to as if only update_b is @@ -857,6 +873,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): self.assertItemsEqual(["a/read:0", "b/read:0"], stepper.intermediate_tensor_names()) self.assertItemsEqual(["d:0"], stepper.handle_names()) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual(set(), stepper.dirty_variables()) result = stepper.cont("e:0") @@ -867,6 +884,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names()) self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"], stepper.intermediate_tensor_names()) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual(set(), stepper.dirty_variables()) # Now run update_a, so as to let Variable a be dirty. @@ -878,7 +896,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): # Due to the update to the value of a:0, the dumped intermediate a/read:0 # should have been invalidated. self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) - # ["b/read:0", "c/read:0"], + self.assertSetEqual({"a:0"}, stepper.last_updated()) self.assertEqual({"a:0"}, stepper.dirty_variables()) # Now, run update_b. @@ -942,6 +960,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): result = stepper.cont("d:0") self.assertAllClose(2.0, result) self.assertEqual({}, stepper.last_feed_types()) + self.assertSetEqual(set(), stepper.last_updated()) self.assertEqual(set(), stepper.dirty_variables()) self.assertEqual(["d:0"], stepper.handle_names()) self.assertSetEqual({"d"}, stepper.handle_node_names()) |