aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-02-06 10:51:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-06 11:19:58 -0800
commit96cbc386976cd4c631ac47a80692bdcdc6d8df11 (patch)
tree8e521a51fd491472ba9936208befb96e2f36dff2
parent1b6fe5a41cddf09063a4944ecb911ff27f598264 (diff)
tfdbg stepper: display last updated variables
* Add public last_updated() method to NodeStepper class. * The Stepper CLI uses this new method to inform the user of what variables (if any) are updated in the last cont/step action. Change: 146683089
-rw-r--r--tensorflow/python/debug/cli/stepper_cli.py94
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py41
-rw-r--r--tensorflow/python/debug/stepper.py30
-rw-r--r--tensorflow/python/debug/stepper_test.py21
4 files changed, 137 insertions, 49 deletions
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index bb76c440bc..6bdbfd72b0 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -65,6 +65,8 @@ class NodeStepperCLI(object):
"Please use full tensor name.",
}
+ _UPDATED_ATTRIBUTE = "bold"
+
_STATE_COLORS = {
STATE_CONT: "green",
STATE_DIRTY_VARIABLE: "magenta",
@@ -74,6 +76,13 @@ class NodeStepperCLI(object):
STATE_UNFEEDABLE: "red",
}
+ _FEED_COLORS = {
+ stepper.NodeStepper.FEED_TYPE_CLIENT: "white",
+ stepper.NodeStepper.FEED_TYPE_HANDLE: "green",
+ stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow",
+ stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue",
+ }
+
def __init__(self, node_stepper):
self._node_stepper = node_stepper
@@ -403,45 +412,13 @@ class NodeStepperCLI(object):
restore_variable_values=parsed.restore_variable_values)
self._completed_nodes.add(parsed.target_name.split(":")[0])
- feed_types = self._node_stepper.last_feed_types()
-
- lines = ["Continued to %s:" % parsed.target_name, ""]
- font_attr_segs = {}
- lines.append("Stepper used feeds:")
- line_counter = len(lines)
-
- if feed_types:
- for feed_name in feed_types:
- feed_info_line = " %s : %s" % (feed_name, feed_types[feed_name])
- lines.append(feed_info_line)
- if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
- font_attr_segs[line_counter] = [
- (len(feed_name) + 2,
- len(feed_info_line),
- self._STATE_COLORS[self.STATE_UNFEEDABLE])
- ]
- elif (feed_types[feed_name] ==
- stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE):
- font_attr_segs[line_counter] = [(
- len(feed_name) + 2,
- len(feed_info_line),
- self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])]
- elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
- font_attr_segs[line_counter] = [
- (len(feed_name) + 2, len(feed_info_line), "yellow")
- ]
- line_counter += 1
- else:
- lines.append(" (No feeds)")
- lines.append("")
-
screen_output = debugger_cli_common.RichTextLines(
- lines, font_attr_segs=font_attr_segs)
-
- tensor_output = tensor_format.format_tensor(
- cont_result, parsed.target_name,
- include_metadata=True)
- screen_output.extend(tensor_output)
+ ["Continued to %s:" % parsed.target_name, ""])
+ screen_output.extend(self._report_last_feed_types())
+ screen_output.extend(self._report_last_updated())
+ screen_output.extend(
+ tensor_format.format_tensor(
+ cont_result, parsed.target_name, include_metadata=True))
# Generate windowed view of the sorted transitive closure on which the
# stepping is occurring.
@@ -458,6 +435,47 @@ class NodeStepperCLI(object):
return final_output
+ def _report_last_feed_types(self):
+ """Generate a report of the feed types used in the cont/step call.
+
+ Returns:
+ (debugger_cli_common.RichTextLines) A RichTextLines representation of the
+ feeds used in the last cont/step call.
+ """
+ feed_types = self._node_stepper.last_feed_types()
+
+ out = debugger_cli_common.RichTextLines(["Stepper used feeds:"])
+ if feed_types:
+ for feed_name in feed_types:
+ feed_info = RL(" %s : " % feed_name)
+ feed_info += RL(feed_types[feed_name],
+ self._FEED_COLORS[feed_types[feed_name]])
+ out.append(feed_info.text, font_attr_segs=feed_info.font_attr_segs)
+ else:
+ out.append(" (No feeds)")
+ out.append("")
+
+ return out
+
+ def _report_last_updated(self):
+ """Generate a report of the variables updated in the last cont/step call.
+
+ Returns:
+ (debugger_cli_common.RichTextLines) A RichTextLines representation of the
+ variables updated in the last cont/step call.
+ """
+
+ last_updated = self._node_stepper.last_updated()
+ if not last_updated:
+ return debugger_cli_common.RichTextLines([])
+
+ rich_lines = [RL("Updated:", self._UPDATED_ATTRIBUTE)]
+ sorted_last_updated = sorted(list(last_updated))
+ for updated in sorted_last_updated:
+ rich_lines.append(RL(" %s" % updated))
+ rich_lines.append(RL(""))
+ return debugger_cli_common.rich_text_lines_from_rich_line_list(rich_lines)
+
def step(self, args, screen_info=None):
"""Step once.
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 0dd4493c95..00a06d1334 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -95,6 +95,38 @@ def _parsed_used_feeds(lines):
feed_types[feed_name] = feed_type
+def _parse_updated(lines):
+ """Parse the Updated section in the output text lines.
+
+ Args:
+ lines: (list of str) The output text lines to be parsed.
+
+ Returns:
+ If the Updated section does not exist, returns None.
+ Otherwise, returns the Tensor names included in the section.
+ """
+ updated = None
+
+ begin_line = -1
+ for i, line in enumerate(lines):
+ if line.startswith("Updated:"):
+ updated = []
+ begin_line = i + 1
+ break
+
+ if begin_line == -1:
+ return updated
+
+ for line in lines[begin_line:]:
+ line = line.strip()
+ if not line:
+ return updated
+ else:
+ updated.append(line.strip())
+
+ return updated
+
+
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -208,6 +240,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertEqual(0, node_pointer)
output = cli.cont("c")
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -218,6 +251,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
output = cli.cont("d")
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -309,6 +343,7 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
cli = stepper_cli.NodeStepperCLI(node_stepper)
output = cli.cont([no_output_node])
+ self.assertIsNone(_parse_updated(output.lines))
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
@@ -333,6 +368,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
cli = stepper_cli.NodeStepperCLI(node_stepper)
output = cli.cont(["opt/update_b/ApplyGradientDescent"])
+ self.assertItemsEqual([self.b.name], _parse_updated(output.lines))
+
output = cli.list_sorted_nodes([])
node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
@@ -346,6 +383,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
output = cli.cont(["opt/update_a/ApplyGradientDescent",
"--invalidate_from_updated_variables"])
+ self.assertItemsEqual([self.a.name], _parse_updated(output.lines))
+
# After cont() call on .../update_a/..., Variable a should have been
# marked as dirty, whereas b should not have.
output = cli.list_sorted_nodes([])
@@ -357,6 +396,8 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
+ self.assertItemsEqual([self.b.name], _parse_updated(output.lines))
+
# After cont() call on .../update_b/... with the -r flag, Variable b
# should have been marked as dirty, whereas Variable a should not be
# because it should have been restored.
diff --git a/tensorflow/python/debug/stepper.py b/tensorflow/python/debug/stepper.py
index 3cbe83d072..ab500f52e3 100644
--- a/tensorflow/python/debug/stepper.py
+++ b/tensorflow/python/debug/stepper.py
@@ -156,6 +156,9 @@ class NodeStepper(object):
# Keep track of which variables are in a dirty state.
self._dirty_variables = set()
+ # Variables updated in the last cont() call.
+ self._last_updated = None
+
# Cached tensor handles: a dict with keys as tensor names and values as
# tensor handles.
self._tensor_handles = {}
@@ -441,9 +444,6 @@ class NodeStepper(object):
return self._last_feed_types
- # TODO(cais): Add method last_updated_variables() to allow client to look
- # up the Variables updated in the last cont() call.
-
def cont(self,
target,
use_tensor_handles=True,
@@ -550,7 +550,7 @@ class NodeStepper(object):
# Keep track of which variables are "touched" (i.e., possibly updated) in
# this cont() call.
- touched_variables = set()
+ self._last_updated = set()
# =========================================================================
# Use a non-recursive method to trace the inputs from the node and set up
@@ -583,7 +583,7 @@ class NodeStepper(object):
if (restore_variable_values and inp.name in self._dirty_variables and
inp.name not in restored_variables and
- inp.name not in touched_variables):
+ inp.name not in self._last_updated):
# Do not restore Variables touched or restored previously in this
# cont() call.
initializer_op = self._variable_initializers[inp.name]
@@ -604,7 +604,7 @@ class NodeStepper(object):
if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"] and
curr_node.type != "Identity"):
# Mark the variable as dirty.
- touched_variables.add(inp.name)
+ self._last_updated.add(inp.name)
# Obtain the old value of the variable and cache it.
if inp.name not in self._cached_variable_values:
@@ -648,8 +648,8 @@ class NodeStepper(object):
# =========================================================================
- if touched_variables:
- self._dirty_variables.update(touched_variables)
+ if self._last_updated:
+ self._dirty_variables.update(self._last_updated)
for variable in restored_variables:
self._dirty_variables.remove(variable)
@@ -686,8 +686,8 @@ class NodeStepper(object):
if invalidate_from_updated_variables:
# Invalidate caches at the end.
- for touched_variable in touched_variables:
- self._invalidate_transitively_outgoing_cache(touched_variable)
+ for last_updated_variable in self._last_updated:
+ self._invalidate_transitively_outgoing_cache(last_updated_variable)
return return_value
@@ -848,6 +848,16 @@ class NodeStepper(object):
return self._dumped_intermediate_tensors.keys()
+ def last_updated(self):
+ """Get the names of the variables updated in the last cont() call.
+
+ Returns:
+ A set of the variable names updated in the previous cont() call.
+ If no cont() call has occurred before, returns None.
+ """
+
+ return self._last_updated
+
def dirty_variables(self):
"""Get the set of variables that are currently "dirty".
diff --git a/tensorflow/python/debug/stepper_test.py b/tensorflow/python/debug/stepper_test.py
index 63501b4fe6..41a62c49b1 100644
--- a/tensorflow/python/debug/stepper_test.py
+++ b/tensorflow/python/debug/stepper_test.py
@@ -587,6 +587,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
def tearDown(self):
ops.reset_default_graph()
+ def testLastUpdatedVariablesReturnsNoneBeforeAnyContCalls(self):
+ with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
+ self.assertIsNone(stepper.last_updated())
+
def testContToUpdateInvalidatesDumpedIntermedates(self):
with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
self.assertAllClose(400.0, stepper.cont("q:0"))
@@ -599,6 +603,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
12.0, stepper.cont(
self.v_add, invalidate_from_updated_variables=True))
self.assertAllClose(12.0, self.sess.run(self.v))
+ self.assertSetEqual({self.v.name}, stepper.last_updated())
self.assertItemsEqual(["v:0"], stepper.dirty_variables())
# Updating the value of v by calling v_add should have invalidated the
# dumped intermediate tensors for v/read:0 and p:0.
@@ -612,6 +617,7 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
# The next cont to q should not have used any dumped intermediate tensors
# and its result should reflect the updated value.
self.assertAllClose(576.0, stepper.cont("q:0"))
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual({}, stepper.last_feed_types())
def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self):
@@ -652,8 +658,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self):
with NodeStepper(self.sess, self.v_add) as stepper:
stepper.cont(self.v_add)
+ self.assertSetEqual({self.v.name}, stepper.last_updated())
self.assertAllClose(12.0, stepper.cont(self.v))
stepper.cont(self.v_add)
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE},
stepper.last_feed_types())
self.assertAllClose(12.0, stepper.cont(self.v))
@@ -661,8 +669,10 @@ class StepperAssignAddTest(test_util.TensorFlowTestCase):
def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self):
with NodeStepper(self.sess, self.v_add_plus_one) as stepper:
stepper.cont(self.v_add_plus_one)
+ self.assertSetEqual({self.v.name}, stepper.last_updated())
self.assertAllClose(12.0, stepper.cont(self.v))
stepper.cont(self.v_add_plus_one)
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE},
stepper.last_feed_types())
self.assertAllClose(12.0, stepper.cont(self.v))
@@ -724,6 +734,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
# Now variable a should have been marked as dirty due to the update
# by optim/update_a/ApplyGradientDescent.
+ self.assertSetEqual({"a:0"}, stepper.last_updated())
self.assertEqual({"a:0"}, stepper.dirty_variables())
self.assertIsNone(result)
self.assertEqual({
@@ -745,6 +756,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
result = stepper.cont("optim/update_b/ApplyGradientDescent",
invalidate_from_updated_variables=True)
self.assertIsNone(result)
+ self.assertSetEqual({"b:0"}, stepper.last_updated())
self.assertEqual(set(["b:0"]), stepper.dirty_variables())
# For backprop on Variable b:
@@ -763,6 +775,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
invalidate_from_updated_variables=True,
restore_variable_values=True)
self.assertIsNone(result)
+ self.assertSetEqual({"a:0"}, stepper.last_updated())
self.assertEqual(set(["a:0"]), stepper.dirty_variables())
self.assertAllClose(0.84, self.sess.run(self.a))
self.assertAllClose(2.0, self.sess.run(self.b))
@@ -790,6 +803,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
invalidate_from_updated_variables=False))
# Even though invalidate_from_updated_variables is set to False, dirty
# variables should still have been tracked.
+ self.assertSetEqual({"a:0"}, stepper.last_updated())
self.assertEqual({"a:0"}, stepper.dirty_variables())
self.assertIn("a/read:0", stepper.intermediate_tensor_names())
self.assertIn("b/read:0", stepper.intermediate_tensor_names())
@@ -825,6 +839,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
invalidate_from_updated_variables=True,
restore_variable_values=True)
self.assertIsNone(result)
+ self.assertSetEqual({"a:0"}, stepper.last_updated())
self.assertEqual({"a:0"}, stepper.dirty_variables())
result = stepper.cont(
@@ -834,6 +849,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
self.assertIsNone(result)
# Variables a and c should have been restored and hence no longer dirty.
# Variable b should have been marked as dirty.
+ self.assertSetEqual({"b:0"}, stepper.last_updated())
self.assertEqual({"b:0"}, stepper.dirty_variables())
# The result of the update should be identitcal to as if only update_b is
@@ -857,6 +873,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
self.assertItemsEqual(["a/read:0", "b/read:0"],
stepper.intermediate_tensor_names())
self.assertItemsEqual(["d:0"], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual(set(), stepper.dirty_variables())
result = stepper.cont("e:0")
@@ -867,6 +884,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names())
self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"],
stepper.intermediate_tensor_names())
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual(set(), stepper.dirty_variables())
# Now run update_a, so as to let Variable a be dirty.
@@ -878,7 +896,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
# Due to the update to the value of a:0, the dumped intermediate a/read:0
# should have been invalidated.
self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
- # ["b/read:0", "c/read:0"],
+ self.assertSetEqual({"a:0"}, stepper.last_updated())
self.assertEqual({"a:0"}, stepper.dirty_variables())
# Now, run update_b.
@@ -942,6 +960,7 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
result = stepper.cont("d:0")
self.assertAllClose(2.0, result)
self.assertEqual({}, stepper.last_feed_types())
+ self.assertSetEqual(set(), stepper.last_updated())
self.assertEqual(set(), stepper.dirty_variables())
self.assertEqual(["d:0"], stepper.handle_names())
self.assertSetEqual({"d"}, stepper.handle_node_names())