-rw-r--r-- | tensorflow/python/debug/BUILD | 1
-rw-r--r-- | tensorflow/python/debug/cli/debugger_cli_common.py | 1
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli.py | 90
-rw-r--r-- | tensorflow/python/debug/cli/stepper_cli_test.py | 485
-rw-r--r-- | tensorflow/python/debug/stepper.py | 228
-rw-r--r-- | tensorflow/python/debug/stepper_test.py | 1394
-rw-r--r-- | tensorflow/python/debug/wrappers/framework.py | 7
-rw-r--r-- | tensorflow/python/debug/wrappers/hooks.py | 11
8 files changed, 1307 insertions(+), 910 deletions(-)
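For context, the user-visible change in this diff is that NodeStepper now dumps intermediate tensors to a temporary directory during each cont() call, reuses those dumps as feeds in later cont() calls (FEED_TYPE_DUMPED_INTERMEDIATE), and cleans the dump directories up when used as a context manager; cont() also gains invalidate_from_updated_variables and use_dumped_intermediates options. Below is a minimal, hedged sketch of the new usage pattern taken from the updated tests, assuming a TF 1.x-era session; the graph names (a, b, c, e) are illustrative, not the exact test fixtures, and the module path tensorflow.python.debug.stepper is as of this change.

```python
# Sketch only: illustrates the context-manager usage introduced in this diff.
import tensorflow as tf
from tensorflow.python.debug import stepper

a = tf.Variable(2.0, name="a")
b = tf.Variable(3.0, name="b")
c = tf.multiply(a, b, name="c")   # c = a * b
e = tf.multiply(c, c, name="e")   # e = c * c

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# NodeStepper is now a context manager: its temporary intermediate-dump
# directories are removed automatically on exit.
with stepper.NodeStepper(sess, e) as node_stepper:
    node_stepper.cont("c:0")
    # Intermediate tensors (e.g. a/read:0, b/read:0) were dumped during the
    # cont() call and are now available without recomputation.
    print(node_stepper.intermediate_tensor_names())

    # This cont() reuses the cached handle for c:0 and the dumped
    # intermediates; last_feed_types() reports FEED_TYPE_DUMPED_INTERMEDIATE
    # for feeds served from dumps.
    node_stepper.cont("e:0", invalidate_from_updated_variables=True)
    print(node_stepper.last_feed_types())
```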
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 22eb840791..339a6a72e0 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -52,6 +52,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":debug_data", + ":debug_utils", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:session_ops", "@six_archive//:six", diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py index f456a2387c..f4274d958d 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common.py +++ b/tensorflow/python/debug/cli/debugger_cli_common.py @@ -90,6 +90,7 @@ class RichLine(object): ret = RichLine() if isinstance(other, str): ret.text = self.text + other + ret.font_attr_segs = self.font_attr_segs[:] return ret elif isinstance(other, RichLine): ret.text = self.text + other.text diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py index 377fcc0387..bb76c440bc 100644 --- a/tensorflow/python/debug/cli/stepper_cli.py +++ b/tensorflow/python/debug/cli/stepper_cli.py @@ -41,7 +41,7 @@ class NodeStepperCLI(object): STATE_CONT = "H" # State where an intermediate dump of the tensor is available. - STATE_INTERMEDIATE = "I" + STATE_DUMPED_INTERMEDIATE = "I" # State where the element is already overridden. STATE_OVERRIDDEN = "O" @@ -53,6 +53,8 @@ class NodeStepperCLI(object): # this NodeStepperCLI instance. STATE_DIRTY_VARIABLE = "D" + STATE_UNFEEDABLE = "U" + NEXT_NODE_POINTER_STR = "-->" _MESSAGE_TEMPLATES = { @@ -63,6 +65,15 @@ class NodeStepperCLI(object): "Please use full tensor name.", } + _STATE_COLORS = { + STATE_CONT: "green", + STATE_DIRTY_VARIABLE: "magenta", + STATE_DUMPED_INTERMEDIATE: "blue", + STATE_OVERRIDDEN: "yellow", + STATE_IS_PLACEHOLDER: "cyan", + STATE_UNFEEDABLE: "red", + } + def __init__(self, node_stepper): self._node_stepper = node_stepper @@ -98,6 +109,14 @@ class NodeStepperCLI(object): type=str, help="Name of the Tensor or Op to continue to.") ap.add_argument( + "-i", + "--invalidate_from_updated_variables", + dest="invalidate_from_updated_variables", + action="store_true", + help="Whether to invalidate the cached " + "tensor handles and intermediate tensor handles affected " + "by Variable updates in this continue call.") + ap.add_argument( "-r", "--restore_variable_values", dest="restore_variable_values", @@ -211,6 +230,7 @@ class NodeStepperCLI(object): verbose = True handle_node_names = self._node_stepper.handle_node_names() + intermediate_tensor_names = self._node_stepper.intermediate_tensor_names() override_names = self._node_stepper.override_names() dirty_variable_names = [ dirty_variable.split(":")[0] @@ -242,6 +262,7 @@ class NodeStepperCLI(object): labels, label_font_attr_segs = self._get_status_labels( element_name, handle_node_names, + intermediate_tensor_names, override_names, dirty_variable_names, len(node_prefix)) @@ -262,6 +283,7 @@ class NodeStepperCLI(object): def _get_status_labels(self, element_name, handle_node_names, + intermediate_tensor_names, override_names, dirty_variable_names, offset): @@ -278,6 +300,7 @@ class NodeStepperCLI(object): element_name: (str) name of the graph element. handle_node_names: (list of str) Names of the nodes of which the output tensors' handles are available. + intermediate_tensor_names: (list of str) TOOD(cais): document. override_names: (list of str) Names of the tensors of which the values are overridden. 
dirty_variable_names: (list of str) Names of the dirty variables. @@ -292,19 +315,31 @@ class NodeStepperCLI(object): status = RL(" " * offset) node_name = element_name.split(":")[0] - status += RL("P", "cyan") if node_name in self._placeholders else " " - status += (RL("U", "red") + status += (RL(self.STATE_IS_PLACEHOLDER, + self._STATE_COLORS[self.STATE_IS_PLACEHOLDER]) + if node_name in self._placeholders else " ") + status += (RL(self.STATE_UNFEEDABLE, + self._STATE_COLORS[self.STATE_UNFEEDABLE]) if not self._node_stepper.is_feedable(str(element_name)) else " ") - status += (RL("H", "green") if element_name in handle_node_names else " ") + status += (RL(self.STATE_CONT, self._STATE_COLORS[self.STATE_CONT]) + if element_name in handle_node_names else " ") + + intermediate_node_names = [ + tensor_name.split(":")[0] for tensor_name in intermediate_tensor_names] + status += (RL(self.STATE_DUMPED_INTERMEDIATE, + self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE]) + if element_name in intermediate_node_names else " ") slots = self._node_stepper.output_slots_in_closure(element_name) has_override = any(element_name + ":%d" % slot in override_names for slot in slots) - status += RL("O", "yellow") if has_override else " " - status += (RL(self.STATE_DIRTY_VARIABLE, "magenta") - if element_name in dirty_variable_names - else " ") + status += (RL(self.STATE_OVERRIDDEN, + self._STATE_COLORS[self.STATE_OVERRIDDEN]) + if has_override else " ") + status += (RL(self.STATE_DIRTY_VARIABLE, + self._STATE_COLORS[self.STATE_DIRTY_VARIABLE]) + if element_name in dirty_variable_names else " ") # TODO(ebreck) Return status here, once the caller is updated with the # RichLine API. @@ -320,14 +355,30 @@ class NodeStepperCLI(object): return debugger_cli_common.rich_text_lines_from_rich_line_list([ RL(""), RL("Legend:"), - RL(" ") + RL("P", "cyan") + " - Placeholder", - RL(" ") + RL("U", "red") + " - Unfeedable", - (RL(" ") + RL("H", "green") + + (RL(" ") + + RL(self.STATE_IS_PLACEHOLDER, + self._STATE_COLORS[self.STATE_IS_PLACEHOLDER]) + + " - Placeholder"), + (RL(" ") + + RL(self.STATE_UNFEEDABLE, + self._STATE_COLORS[self.STATE_UNFEEDABLE]) + + " - Unfeedable"), + (RL(" ") + + RL(self.STATE_CONT, + self._STATE_COLORS[self.STATE_CONT]) + " - Already continued-to; Tensor handle available from output " "slot(s)"), - (RL(" ") + RL("O", "yellow") + + (RL(" ") + + RL(self.STATE_DUMPED_INTERMEDIATE, + self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE]) + + " - Unfeedable"), + (RL(" ") + + RL(self.STATE_OVERRIDDEN, + self._STATE_COLORS[self.STATE_OVERRIDDEN]) + " - Has overriding (injected) tensor value"), - (RL(" ") + RL("D", "magenta") + + (RL(" ") + + RL(self.STATE_DIRTY_VARIABLE, + self._STATE_COLORS[self.STATE_DIRTY_VARIABLE]) + " - Dirty variable: Variable already updated this node stepper.")]) def cont(self, args, screen_info=None): @@ -347,6 +398,8 @@ class NodeStepperCLI(object): cont_result = self._node_stepper.cont( parsed.target_name, + invalidate_from_updated_variables=( + parsed.invalidate_from_updated_variables), restore_variable_values=parsed.restore_variable_values) self._completed_nodes.add(parsed.target_name.split(":")[0]) @@ -363,8 +416,16 @@ class NodeStepperCLI(object): lines.append(feed_info_line) if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE: font_attr_segs[line_counter] = [ - (len(feed_name) + 2, len(feed_info_line), "green") + (len(feed_name) + 2, + len(feed_info_line), + self._STATE_COLORS[self.STATE_UNFEEDABLE]) ] + elif (feed_types[feed_name] == + 
stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE): + font_attr_segs[line_counter] = [( + len(feed_name) + 2, + len(feed_info_line), + self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])] elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE: font_attr_segs[line_counter] = [ (len(feed_name) + 2, len(feed_info_line), "yellow") @@ -542,4 +603,3 @@ class NodeStepperCLI(object): return [(element_name + ":%d" % slot) for slot in slots] else: return [] - diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py index 78c56697ae..0dd4493c95 100644 --- a/tensorflow/python/debug/cli/stepper_cli_test.py +++ b/tensorflow/python/debug/cli/stepper_cli_test.py @@ -140,310 +140,339 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase): self.assertLess(node_names.index("e"), node_names.index("f")) def testListingSortedNodesPresentsTransitveClosure(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.list_sorted_nodes([]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - self._assert_nodes_topologically_sorted_with_target_e(node_names) - self.assertEqual(len(node_names), len(stat_labels)) - for stat_label in stat_labels: - self.assertEqual(" ", stat_label) - self.assertEqual(0, node_pointer) + self._assert_nodes_topologically_sorted_with_target_e(node_names) + self.assertEqual(len(node_names), len(stat_labels)) + for stat_label in stat_labels: + self.assertEqual(" ", stat_label) + self.assertEqual(0, node_pointer) def testListingSortedNodesLabelsPlaceholders(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f)) + with stepper.NodeStepper(self.sess, self.f) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.list_sorted_nodes([]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - self._assert_nodes_topologically_sorted_with_target_f(node_names) + self._assert_nodes_topologically_sorted_with_target_f(node_names) - index_ph = node_names.index("ph") - self.assertEqual(len(node_names), len(stat_labels)) - for i in xrange(len(stat_labels)): - if index_ph == i: - self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER, - stat_labels[i]) - else: - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER, - stat_labels[i]) + index_ph = node_names.index("ph") + self.assertEqual(len(node_names), len(stat_labels)) + for i in xrange(len(stat_labels)): + if index_ph == i: + self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER, + stat_labels[i]) + else: + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER, + stat_labels[i]) - self.assertEqual(0, node_pointer) + self.assertEqual(0, node_pointer) def testContToNonexistentNodeShouldError(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f)) + with stepper.NodeStepper(self.sess, self.f) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.cont(["foobar"]) - self.assertEqual([ - "ERROR: foobar is not in the transitive closure of this stepper " - "instance." 
- ], output.lines) + output = cli.cont(["foobar"]) + self.assertEqual([ + "ERROR: foobar is not in the transitive closure of this stepper " + "instance." + ], output.lines) def testContToNodeOutsideTransitiveClosureShouldError(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.cont(["f"]) - self.assertEqual([ - "ERROR: f is not in the transitive closure of this stepper " - "instance." - ], output.lines) + output = cli.cont(["f"]) + self.assertEqual([ + "ERROR: f is not in the transitive closure of this stepper " + "instance." + ], output.lines) def testContToValidNodeShouldUpdateStatus(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.list_sorted_nodes([]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - index_c = node_names.index("c") - self.assertEqual(" ", stat_labels[index_c]) - self.assertEqual(0, node_pointer) + index_c = node_names.index("c") + self.assertEqual(" ", stat_labels[index_c]) + self.assertEqual(0, node_pointer) - output = cli.cont("c") - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.cont("c") + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - self.assertGreaterEqual(len(node_names), 3) - self.assertIn("c", node_names) - index_c = node_names.index("c") - self.assertEqual(index_c, node_pointer) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c]) + self.assertGreaterEqual(len(node_names), 3) + self.assertIn("c", node_names) + index_c = node_names.index("c") + self.assertEqual(index_c, node_pointer) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c]) - output = cli.cont("d") - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.cont("d") + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - used_feed_types = _parsed_used_feeds(output.lines) - self.assertEqual({"c:0": "handle"}, used_feed_types) + used_feed_types = _parsed_used_feeds(output.lines) + self.assertEqual({ + "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE, + "a/read:0": stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, used_feed_types) - self.assertGreaterEqual(len(node_names), 3) - self.assertIn("d", node_names) - index_d = node_names.index("d") - self.assertEqual(index_d, node_pointer) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d]) + self.assertGreaterEqual(len(node_names), 3) + self.assertIn("d", node_names) + index_d = node_names.index("d") + self.assertEqual(index_d, node_pointer) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d]) def testSteppingOneStepAtATimeShouldUpdateStatus(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.list_sorted_nodes([]) - orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) - self.assertEqual(0, node_pointer) - - for i in xrange(len(orig_node_names)): + output = 
cli.list_sorted_nodes([]) + orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) + self.assertEqual(0, node_pointer) + + for i in xrange(len(orig_node_names)): + output = cli.step([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) + + next_node_name = node_names[node_pointer] + self.assertEqual(orig_node_names[i], next_node_name) + + self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, + stat_labels[node_pointer]) + + # The order in which the nodes are listed should not change as the + # stepping happens. + output = cli.list_sorted_nodes([]) + node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) + self.assertEqual(orig_node_names, node_names) + + if i < len(orig_node_names) - 1: + self.assertEqual(i + 1, node_pointer) + else: + # Stepped over the limit. Pointer should be at -1. + self.assertEqual(-1, node_pointer) + + # Attempt to step once more after the end has been reached should error + # out. output = cli.step([]) + self.assertEqual([ + "ERROR: Cannot step any further because the end of the sorted " + "transitive closure has been reached." + ], output.lines) + + def testSteppingMultipleStepsUpdatesStatus(self): + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) + + output = cli.list_sorted_nodes([]) + orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines) + + output = cli.step(["-t", "3"]) node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( output.lines) - next_node_name = node_names[node_pointer] - self.assertEqual(orig_node_names[i], next_node_name) + self.assertEqual(orig_node_names[2], node_names[node_pointer]) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, - stat_labels[node_pointer]) + for i in xrange(node_pointer): + self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i]) - # The order in which the nodes are listed should not change as the - # stepping happens. - output = cli.list_sorted_nodes([]) - node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) - self.assertEqual(orig_node_names, node_names) + for i in xrange(node_pointer + 1, len(stat_labels)): + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i]) - if i < len(orig_node_names) - 1: - self.assertEqual(i + 1, node_pointer) - else: - # Stepped over the limit. Pointer should be at -1. - self.assertEqual(-1, node_pointer) + def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self): + with stepper.NodeStepper(self.sess, self.opt) as node_stepper: + sorted_nodes = node_stepper.sorted_nodes() + closure_elements = node_stepper.closure_elements() + + # Find a node which is in the list of sorted nodes, but whose output + # Tensor is not in the transitive closure. + no_output_node = None + for node in sorted_nodes: + if (node + ":0" not in closure_elements and + node + ":1" not in closure_elements): + no_output_node = node + break + + self.assertIsNotNone(no_output_node) + + cli = stepper_cli.NodeStepperCLI(node_stepper) + output = cli.cont([no_output_node]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - # Attempt to step once more after the end has been reached should error out. - output = cli.step([]) - self.assertEqual([ - "ERROR: Cannot step any further because the end of the sorted " - "transitive closure has been reached." 
- ], output.lines) + self.assertEqual(no_output_node, node_names[node_pointer]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, + stat_labels[node_pointer]) - def testSteppingMultipleStepsUpdatesStatus(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + def testContToUpdateNodeWithTrackingLeadsToDirtyVariableLabel(self): + with stepper.NodeStepper(self.sess, self.opt) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) + output = cli.cont(["opt/update_b/ApplyGradientDescent", "-i"]) - output = cli.list_sorted_nodes([]) - orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("b")]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("a")]) - output = cli.step(["-t", "3"]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self): + with stepper.NodeStepper(self.sess, self.opt) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) + output = cli.cont(["opt/update_b/ApplyGradientDescent"]) - self.assertEqual(orig_node_names[2], node_names[node_pointer]) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("b")]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("a")]) - for i in xrange(node_pointer): - self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i]) + def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self): + with stepper.NodeStepper(self.sess, self.opt) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) + output = cli.cont(["opt/update_a/ApplyGradientDescent", + "--invalidate_from_updated_variables"]) - for i in xrange(node_pointer + 1, len(stat_labels)): - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i]) + # After cont() call on .../update_a/..., Variable a should have been + # marked as dirty, whereas b should not have. + output = cli.list_sorted_nodes([]) + node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("a")]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("b")]) - def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self): - node_stepper = stepper.NodeStepper(self.sess, self.opt) - - sorted_nodes = node_stepper.sorted_nodes() - closure_elements = node_stepper.closure_elements() - - # Find a node which is in the list of sorted nodes, but whose output tensor - # is not in the transitive closure. 
- no_output_node = None - for node in sorted_nodes: - if (node + ":0" not in closure_elements and - node + ":1" not in closure_elements): - no_output_node = node - break - - self.assertIsNotNone(no_output_node) - - cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.cont([no_output_node]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) - - self.assertEqual(no_output_node, node_names[node_pointer]) - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, - stat_labels[node_pointer]) - - def testContToUpdateNodeLeadsToDirtyVariableLabel(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt)) - output = cli.cont(["opt/update_b/ApplyGradientDescent"]) - - output = cli.list_sorted_nodes([]) - node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("b")]) - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("a")]) + output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"]) - def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt)) - output = cli.cont(["opt/update_a/ApplyGradientDescent"]) - - # After cont() call on .../update_a/..., Variable a should have been marked - # as dirty, whereas b should not have. - output = cli.list_sorted_nodes([]) - node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("a")]) - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("b")]) - - output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"]) - - # After cont() call on .../update_b/... with the -r flag, Variable b should - # have been marked as dirty, whereas Variable a should not be because it - # should have been restored. - output = cli.list_sorted_nodes([]) - node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("b")]) - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, - stat_labels[node_names.index("a")]) + # After cont() call on .../update_b/... with the -r flag, Variable b + # should have been marked as dirty, whereas Variable a should not be + # because it should have been restored. 
+ output = cli.list_sorted_nodes([]) + node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("b")]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE, + stat_labels[node_names.index("a")]) def testPrintTensorShouldWorkWithTensorName(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - cli.cont("d") - output = cli.print_tensor(["d:0"]) + cli.cont("d") + output = cli.print_tensor(["d:0"]) - self.assertEqual("Tensor \"d:0\":", output.lines[0]) - self.assertEqual("-20.0", output.lines[-1]) + self.assertEqual("Tensor \"d:0\":", output.lines[0]) + self.assertEqual("-20.0", output.lines[-1]) def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - cli.cont("d") - output = cli.print_tensor(["d"]) + cli.cont("d") + output = cli.print_tensor(["d"]) - self.assertEqual("Tensor \"d:0\":", output.lines[0]) - self.assertEqual("-20.0", output.lines[-1]) + self.assertEqual("Tensor \"d:0\":", output.lines[0]) + self.assertEqual("-20.0", output.lines[-1]) def testPrintTensorShouldWorkSlicingString(self): ph_value = np.array([[1.0, 0.0], [0.0, 2.0]]) - cli = stepper_cli.NodeStepperCLI( - stepper.NodeStepper( - self.sess, self.f, feed_dict={self.ph: ph_value})) + with stepper.NodeStepper( + self.sess, self.f, feed_dict={self.ph: ph_value}) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.print_tensor(["ph:0[:, 1]"]) - self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0]) - self.assertEqual(repr(ph_value[:, 1]), output.lines[-1]) + output = cli.print_tensor(["ph:0[:, 1]"]) + self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0]) + self.assertEqual(repr(ph_value[:, 1]), output.lines[-1]) - output = cli.print_tensor(["ph[:, 1]"]) - self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0]) - self.assertEqual(repr(ph_value[:, 1]), output.lines[-1]) + output = cli.print_tensor(["ph[:, 1]"]) + self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0]) + self.assertEqual(repr(ph_value[:, 1]), output.lines[-1]) def testPrintTensorWithNonexistentTensorShouldError(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.print_tensor(["foobar"]) - self.assertEqual([ - "ERROR: foobar is not in the transitive closure of this stepper " - "instance." - ], output.lines) + output = cli.print_tensor(["foobar"]) + self.assertEqual([ + "ERROR: foobar is not in the transitive closure of this stepper " + "instance." 
+ ], output.lines) def testPrintTensorWithNoHandleShouldError(self): - cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e)) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.print_tensor("e") - self.assertEqual([ - "This stepper instance does not have access to the value of tensor " - "\"e:0\"" - ], output.lines) + output = cli.print_tensor("e") + self.assertEqual([ + "This stepper instance does not have access to the value of tensor " + "\"e:0\"" + ], output.lines) def testInjectTensorValueByTensorNameShouldBeReflected(self): - node_stepper = stepper.NodeStepper(self.sess, self.e) - cli = stepper_cli.NodeStepperCLI(node_stepper) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - output = cli.cont(["d"]) - node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) - self.assertEqual("d", node_names[node_pointer]) + output = cli.cont(["d"]) + node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines) + self.assertEqual("d", node_names[node_pointer]) - output = cli.list_sorted_nodes([]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + output = cli.list_sorted_nodes([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - index_d = node_names.index("d") - self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d]) - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, - stat_labels[index_d]) + index_d = node_names.index("d") + self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d]) + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, + stat_labels[index_d]) - self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0")) + self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0")) - output = cli.inject_value(["d:0", "20.0"]) + output = cli.inject_value(["d:0", "20.0"]) - # Verify that the override is available. - self.assertEqual(["d:0"], node_stepper.override_names()) + # Verify that the override is available. + self.assertEqual(["d:0"], node_stepper.override_names()) - # Verify that the list of sorted nodes reflects the existence of the value - # override (i.e., injection). - output = cli.list_sorted_nodes([]) - node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( - output.lines) + # Verify that the list of sorted nodes reflects the existence of the value + # override (i.e., injection). 
+ output = cli.list_sorted_nodes([]) + node_names, stat_labels, node_pointer = _parse_sorted_nodes_list( + output.lines) - index_d = node_names.index("d") - self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, - stat_labels[index_d]) - self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, - stat_labels[index_d]) + index_d = node_names.index("d") + self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, + stat_labels[index_d]) + self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN, + stat_labels[index_d]) def testInjectTensorValueByNodeNameShouldBeReflected(self): - node_stepper = stepper.NodeStepper(self.sess, self.e) - cli = stepper_cli.NodeStepperCLI(node_stepper) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) - cli.inject_value(["d", "20.0"]) - self.assertEqual(["d:0"], node_stepper.override_names()) + cli.inject_value(["d", "20.0"]) + self.assertEqual(["d:0"], node_stepper.override_names()) def testInjectToNonexistentTensorShouldError(self): - node_stepper = stepper.NodeStepper(self.sess, self.e) - cli = stepper_cli.NodeStepperCLI(node_stepper) - - output = cli.inject_value(["foobar:0", "20.0"]) - self.assertEqual([ - "ERROR: foobar:0 is not in the transitive closure of this stepper " - "instance." - ], output.lines) + with stepper.NodeStepper(self.sess, self.e) as node_stepper: + cli = stepper_cli.NodeStepperCLI(node_stepper) + + output = cli.inject_value(["foobar:0", "20.0"]) + self.assertEqual([ + "ERROR: foobar:0 is not in the transitive closure of this stepper " + "instance." + ], output.lines) if __name__ == "__main__": diff --git a/tensorflow/python/debug/stepper.py b/tensorflow/python/debug/stepper.py index fb5e972627..3cbe83d072 100644 --- a/tensorflow/python/debug/stepper.py +++ b/tensorflow/python/debug/stepper.py @@ -18,11 +18,16 @@ from __future__ import division from __future__ import print_function import copy +import os +import shutil +import tempfile +import time import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.debug import debug_data +from tensorflow.python.debug import debug_utils from tensorflow.python.framework import ops from tensorflow.python.ops import session_ops @@ -64,9 +69,17 @@ class NodeStepper(object): tree of the target. When it reaches an input where one of the following is available, it will supply the available value to the feed_dict of the cont() call: - (1) TensorHandles from previous cont() calls. - (2) Overriding (injected) values from the client. - (3) Feeds supplied during the construction of the stepper instance. + (1) Overriding (injected) values from the client. + (2) TensorHandles from previous cont() calls. + (3) Dumped intermediate Tensors from previous cont() calls. + (4) Feeds supplied during the construction of the stepper instance. + + During the cont() call, intermediate Tensors are dumped to temporary + directories. The dumped Tensor values will be used in subsequent cont() calls + when they are required as data dependencies. + + The temporary directories are automatically clean when the NodeStepper + instance exits as a context mananger. Once the tracing is complete, it will issue a run() call on the underlying session, using the aforementioned feed_dict prepared by the input @@ -95,10 +108,7 @@ class NodeStepper(object): FEED_TYPE_CLIENT = "client" FEED_TYPE_HANDLE = "handle" FEED_TYPE_OVERRIDE = "override" - - # TODO(cais): The following member constant is currently unused. 
Use it when - # the stepper is capable of using dumped intermediate tensors. - FEED_TYPE_INTERMEDIATE = "intermediate" + FEED_TYPE_DUMPED_INTERMEDIATE = "dumped_intermediate" def __init__(self, sess, fetches, feed_dict=None): """Constructor for Debugger. @@ -125,11 +135,15 @@ class NodeStepper(object): self._variable_initial_values = {} # Initialize the map for output recipients (targets). - self._non_control_output_targets = {} + self._output_targets = {} # Sorted transitive closure of the fetched node. - self._sorted_nodes, self._closure_elements = self._dfs_visit( - self._sess.graph, self._fetch_list) + # We also collect the list of the names of the reference-type Tensors, + # because we later need to avoid using intermediate dumps for such Tensors. + (self._sorted_nodes, + self._closure_elements, + self._ref_tensor_names) = self._dfs_visit(self._sess.graph, + self._fetch_list) self._transitive_closure_set = set(self._sorted_nodes) @@ -146,6 +160,11 @@ class NodeStepper(object): # tensor handles. self._tensor_handles = {} + # Cached intermediate tensor values: a dict mapping tensor names to + # DebugTensorDatum. + self._dumped_intermediate_tensors = {} + self._dump_session_root = tempfile.mkdtemp(prefix="tfdbg_stepper_") + # Feed dict from the client. self._client_feed_dict = {} if feed_dict: @@ -161,6 +180,13 @@ class NodeStepper(object): # What the feed types were used by the last cont() call. self._last_feed_types = {} + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if os.path.isdir(self._dump_session_root): + shutil.rmtree(self._dump_session_root) + def _get_fetch_and_name_lists(self, flattened_fetches): """Get the lists of fetches and their names. @@ -220,6 +246,11 @@ class NodeStepper(object): # Graph elements in the transitive closure, including the nodes and tensors. closure_elements = [elem.name for elem in elem_list] + ref_tensor_names = set() + for element in elem_list: + if isinstance(element, ops.Tensor) and element.dtype._is_ref_dtype: # pylint: disable=protected-access + ref_tensor_names.add(element.name) + while elem_stack: curr_elem = elem_stack.pop() curr_node = self._get_node(curr_elem) @@ -238,22 +269,21 @@ class NodeStepper(object): # Iterate through the (non-control) inputs. for inp in all_inputs: - is_non_control_input = inp in non_control_inputs - # Set up the non-control output map. - if is_non_control_input: - if inp.name not in self._non_control_output_targets: - self._non_control_output_targets[inp.name] = set([curr_elem.name]) - else: - self._non_control_output_targets[inp.name].add(curr_elem.name) - - if (inp.op.type in ["Variable", "VariableV2"] and - inp.name not in self._variable_initializers): - # Obtain the initializer op of the variable, in case the Variable's - # value needs to be restored later. - initializer = graph.as_graph_element(inp.op.name + "/Assign") - self._variable_initializers[inp.name] = initializer - self._variable_initial_values[inp.name] = initializer.inputs[1] + # if is_non_control_input: + if inp.name not in self._output_targets: + self._output_targets[inp.name] = set([curr_elem.name]) + else: + self._output_targets[inp.name].add(curr_elem.name) + + if (isinstance(inp, ops.Tensor) and + inp.op.type in ["Variable", "VariableV2"] and + inp.name not in self._variable_initializers): + # Obtain the initializer op of the variable, in case the Variable's + # value needs to be restored later. 
+ initializer = graph.as_graph_element(inp.op.name + "/Assign") + self._variable_initializers[inp.name] = initializer + self._variable_initial_values[inp.name] = initializer.inputs[1] inp_node = self._get_node(inp) if inp_node.name in done: @@ -262,6 +292,8 @@ class NodeStepper(object): elem_stack.append(inp) closure_elements.append(inp.name) + if isinstance(inp, ops.Tensor) and inp.dtype._is_ref_dtype: # pylint: disable=protected-access + ref_tensor_names.add(inp.name) # Now that we have traversed the transitive closure and obtained the # node-input map, we can topologically sort them. @@ -291,7 +323,7 @@ class NodeStepper(object): stack.extend(pushes) - return sorted_nodes, closure_elements + return sorted_nodes, closure_elements, ref_tensor_names def sorted_nodes(self): """Get a topologically-sorted list of node names of the stepper. @@ -409,10 +441,15 @@ class NodeStepper(object): return self._last_feed_types + # TODO(cais): Add method last_updated_variables() to allow client to look + # up the Variables updated in the last cont() call. + def cont(self, target, use_tensor_handles=True, + use_dumped_intermediates=True, use_overrides=True, + invalidate_from_updated_variables=False, restore_variable_values=False): """Continue till the completion of the specified target tensor. @@ -423,8 +460,13 @@ class NodeStepper(object): # TODO(cais): Support multiple fetches as in Session.run() interface. use_tensor_handles: (bool) Whether this cont() run will use cached tensor handles to avoid recomputation. Default: True. + use_dumped_intermediates: (bool) Whether this cont() call will use dumped + intermediate tensors to avoid recomputation. use_overrides: (bool) Whether the overriding tensor values supplied by the client are to be used in this cont() call. Default: True. + invalidate_from_updated_variables: (bool) Whether to invalidate the + tensor handles and intermediate tensor handles affected by the + Variable updates that happen in this cont() call. restore_variable_values: (bool) Whether the old values of the variables (before any cont() calls in this object) are to be restored. @@ -441,9 +483,6 @@ class NodeStepper(object): self._last_feed_types = {} - # The feeds to be used in the Session.run() call. - feeds = {} - if isinstance(target, six.string_types): # Fetch target is a string. Assume it is the name of the Tensor or Op and # will attempt to find it in the Session's graph. @@ -494,6 +533,12 @@ class NodeStepper(object): self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE return self._tensor_handles[target_name].eval() + # Check if a dumped intermediate tensor can be used on the fetch directly. + if (use_dumped_intermediates and + target_name in self._dumped_intermediate_tensors): + self._last_feed_types[target_name] = self.FEED_TYPE_DUMPED_INTERMEDIATE + return self._dumped_intermediate_tensors[target_name].get_tensor() + # Check if an overriding tensor value can be used directly. if use_overrides and target_name in self._override_tensors: # Override is available. Return the value right away. @@ -510,6 +555,7 @@ class NodeStepper(object): # ========================================================================= # Use a non-recursive method to trace the inputs from the node and set up # the feeds. + feeds = {} # The feeds to be used in the Session.run() call. fetched = self._sess.graph.as_graph_element(target_name) elem_stack = [fetched] done = set() @@ -529,7 +575,7 @@ class NodeStepper(object): # Determine whether the input is feedable. 
Reference-type tensors, # e.g., Variables, should not be fed, because they can change. if isinstance(inp, ops.Tensor): - is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access + is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref else: is_inp_ref = False @@ -574,11 +620,17 @@ class NodeStepper(object): # Use client-supplied overriding tensor value. feeds[inp] = self._override_tensors[inp.name] self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE - elif (use_tensor_handles and can_feed and - inp.name in self._tensor_handles and inp not in feeds): + elif (can_feed and inp not in feeds and + use_tensor_handles and inp.name in self._tensor_handles): # Tensor handle found in cache. feeds[inp] = self._tensor_handles[inp.name].eval() self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE + elif (can_feed and inp not in feeds and + use_dumped_intermediates and + inp.name in self._dumped_intermediate_tensors): + # Dumped intermediate Tensor found. + feeds[inp] = self._dumped_intermediate_tensors[inp.name].get_tensor() + self._last_feed_types[inp.name] = self.FEED_TYPE_DUMPED_INTERMEDIATE elif inp.name in self._client_feed_dict: # This input is available in the client feed_dict. feeds[inp] = self._client_feed_dict[inp.name] @@ -602,14 +654,12 @@ class NodeStepper(object): for variable in restored_variables: self._dirty_variables.remove(variable) - # Prepare RunOptions for DebugTensorWatches - run_options = config_pb2.RunOptions() - # TODO(cais): Add fields for watching intermediate tensors. - + (dump_path, + run_options) = self._prepare_cont_call_dump_path_and_run_options() if isinstance(fetched, ops.Operation): # The fetched is an Operation: Will not get tensor handle. self._sess.run(fetched, feed_dict=feeds, options=run_options) - # No return value for a run of an Operation + return_value = None else: # This is a Tensor: Will get tensor handle and cache it. # Will also get the additional requested tensor handles (if any). @@ -622,17 +672,58 @@ class NodeStepper(object): ]) handle_names.extend(additional_handle_requests) - for handle_name, tensor in zip(handle_names, tensors_to_get_handles_for): - handle = self._sess.run(session_ops.get_session_handle(tensor), - feed_dict=feeds, - options=run_options) + handles = self._sess.run( + [session_ops.get_session_handle(tensor) for tensor in + tensors_to_get_handles_for], + feed_dict=feeds, + options=run_options) + for handle_name, handle in zip(handle_names, handles): self._tensor_handles[handle_name] = handle - return self._tensor_handles[target_name].eval() + return_value = self._tensor_handles[target_name].eval() + + self._load_dumped_intermediate_tensors(dump_path, target_name) - # Invalidate caches at the end. - for touched_variable in touched_variables: - self._invalidate_transitively_outgoing_cache(touched_variable) + if invalidate_from_updated_variables: + # Invalidate caches at the end. + for touched_variable in touched_variables: + self._invalidate_transitively_outgoing_cache(touched_variable) + + return return_value + + def _prepare_cont_call_dump_path_and_run_options(self): + """Prepare the dump path and RunOptions for next cont() call. + + Returns: + dump_path: (str) Directory path to which the intermediate tensor will be + dumped. + run_options: (config_pb2.RunOptions) The RunOptions containing the tensor + watch options for this graph. 
+ """ + run_options = config_pb2.RunOptions() + dump_path = self._cont_call_dump_path() + for element_name in self._closure_elements: + if ":" in element_name: + debug_utils.add_debug_tensor_watch( + run_options, + debug_data.get_node_name(element_name), + output_slot=debug_data.get_output_slot(element_name), + debug_urls=["file://" + dump_path]) + + return dump_path, run_options + + def _cont_call_dump_path(self): + return os.path.join(self._dump_session_root, + "cont_%d" % int(time.time() * 1e6)) + + def _load_dumped_intermediate_tensors(self, dump_path, target_name): + dump_dir = debug_data.DebugDumpDir(dump_path, validate=False) + for dump in dump_dir.dumped_tensor_data: + if (dump.tensor_name not in self._ref_tensor_names and + dump.tensor_name not in self._tensor_handles and + dump.tensor_name not in self._override_tensors and + dump.tensor_name != target_name): + self._dumped_intermediate_tensors[dump.tensor_name] = dump def _get_node_name(self, graph_element_name): return graph_element_name.split(":")[0] @@ -646,29 +737,33 @@ class NodeStepper(object): Uses non-recursive implementation to avoid stack overflow on deep networks. - TODO(cais): Currently, only TensorHandle caches are invalidated. Invalidate - cached intermediate tensor values from dumps when dumps are added. - Args: source_element: The source graph element (e.g., a Variable output slot) to trace the output from. """ - if not self._tensor_handles: + if not self._tensor_handles and not self._dumped_intermediate_tensors: return # First, use cached invalidation paths to eliminate some cached tensor - # handles. - to_delete = [] + # handles and intermediate tensors. + to_delete_handles = [] for handle_name in self._tensor_handles: if (handle_name in self._cached_invalidation_path and source_element in self._cached_invalidation_path[handle_name]): - to_delete.append(handle_name) - - for handle_name in to_delete: + to_delete_handles.append(handle_name) + for handle_name in to_delete_handles: del self._tensor_handles[handle_name] - if not self._tensor_handles: + to_delete_intermediates = [] + for intm_tensor_name in self._dumped_intermediate_tensors: + if (intm_tensor_name in self._cached_invalidation_path and + source_element in self._cached_invalidation_path[intm_tensor_name]): + to_delete_intermediates.append(intm_tensor_name) + for intermediate in to_delete_intermediates: + del self._dumped_intermediate_tensors[intermediate] + + if not self._tensor_handles and not self._dumped_intermediate_tensors: return stack = [source_element] @@ -676,19 +771,22 @@ class NodeStepper(object): while stack: curr_element = stack.pop() - done.add(curr_element) - if curr_element in self._tensor_handles: + if (curr_element in self._tensor_handles or + curr_element in self._dumped_intermediate_tensors): # Cache the invalidation path for potential future use. 
if curr_element not in self._cached_invalidation_path: self._cached_invalidation_path[curr_element] = set([source_element]) else: self._cached_invalidation_path[curr_element].add(source_element) - del self._tensor_handles[curr_element] + if curr_element in self._tensor_handles: + del self._tensor_handles[curr_element] + else: + del self._dumped_intermediate_tensors[curr_element] - targets = self._non_control_output_targets.get(curr_element, []) + targets = self._output_targets.get(curr_element, []) for target in targets: if target in done: continue @@ -740,6 +838,16 @@ class NodeStepper(object): return set([self._get_node_name(name) for name in self._tensor_handles]) + def intermediate_tensor_names(self): + """Get list of the names of the Tensors for which dumps are available. + + Returns: + (list of str) List of the names of the Tensors for which intermediate + dumps are available. + """ + + return self._dumped_intermediate_tensors.keys() + def dirty_variables(self): """Get the set of variables that are currently "dirty". @@ -817,6 +925,8 @@ class NodeStepper(object): return self._override_tensors[tensor_name] elif tensor_name in self._tensor_handles: return self._tensor_handles[tensor_name].eval() + elif tensor_name in self._dumped_intermediate_tensors: + return self._dumped_intermediate_tensors[tensor_name].get_tensor() else: raise ValueError( "This stepper instance does not have access to the value of " diff --git a/tensorflow/python/debug/stepper_test.py b/tensorflow/python/debug/stepper_test.py index 20d79f3ade..63501b4fe6 100644 --- a/tensorflow/python/debug/stepper_test.py +++ b/tensorflow/python/debug/stepper_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import gradient_descent @@ -54,277 +55,310 @@ class StepperTest(test_util.TensorFlowTestCase): self.sess = session.Session() self.sess.run(variables.global_variables_initializer()) - self.sess = session.Session() - self.sess.run(variables.global_variables_initializer()) - def tearDown(self): ops.reset_default_graph() def testContToFetchNotInTransitiveClosureShouldError(self): - stepper = NodeStepper(self.sess, "e:0") - - sorted_nodes = stepper.sorted_nodes() - self.assertEqual(7, len(sorted_nodes)) - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) - self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) - self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) - self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) - self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) - - self.assertSetEqual( - {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"}, - set(stepper.closure_elements())) - - with self.assertRaisesRegexp( - ValueError, - "Target \"f:0\" is not in the transitive closure for the fetch of the " - "stepper"): - stepper.cont("f:0") - - def testContToNodeNameShouldReturnTensorvalue(self): - stepper = NodeStepper(self.sess, "e:0") - - cont_result = stepper.cont("c") - self.assertAllClose(6.0, cont_result) + with NodeStepper(self.sess, "e:0") as stepper: + sorted_nodes = 
stepper.sorted_nodes() + self.assertEqual(7, len(sorted_nodes)) + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) + self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) + self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) + self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) + self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) - def testUsingNamesNotUsingIntermediateTensors(self): - stepper = NodeStepper(self.sess, "e:0") + self.assertSetEqual( + {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"}, + set(stepper.closure_elements())) - # The first cont() call should have used no feeds. - result = stepper.cont("c:0") - self.assertAllClose(6.0, result) - self.assertEqual({}, stepper.last_feed_types()) + with self.assertRaisesRegexp( + ValueError, + "Target \"f:0\" is not in the transitive closure for the fetch of " + "the stepper"): + stepper.cont("f:0") - # The second cont() call should have used the tensor handle from the - # previous cont() call. - result = stepper.cont("e:0") - self.assertAllClose(24.0, result) - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) + def testContToNodeNameShouldReturnTensorValue(self): + with NodeStepper(self.sess, "e:0") as stepper: + self.assertAllClose(6.0, stepper.cont("c")) - def testUsingNodesNotUsingIntermediateTensors(self): - stepper = NodeStepper(self.sess, self.e) + def testUsingNamesNotUsingIntermediateTensors(self): + with NodeStepper(self.sess, "e:0") as stepper: + # The first cont() call should have used no feeds. + result = stepper.cont("c:0") + self.assertAllClose(6.0, result) + self.assertItemsEqual(["a/read:0", "b/read:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) + self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) + self.assertEqual({}, stepper.last_feed_types()) + + # The second cont() call should have used the tensor handle from the + # previous cont() call. + result = stepper.cont("e:0") + self.assertAllClose(24.0, result) + self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) + self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) + self.assertAllClose(4.0, stepper.get_tensor_value("d:0")) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_HANDLE, + "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, stepper.last_feed_types()) - # There should be no handles before any cont() calls. - self.assertEqual([], stepper.handle_names()) - self.assertSetEqual(set(), stepper.handle_node_names()) - - # Before the cont() call, the stepper should not have access to the value - # of c:0. - with self.assertRaisesRegexp( - ValueError, - "This stepper instance does not have access to the value of tensor " - "\"c:0\""): - stepper.get_tensor_value("c:0") - - # Using the node/tensor itself, instead of the name str, should work on - # cont(). - result = stepper.cont(self.c) - self.assertAllClose(6.0, result) - self.assertEqual({}, stepper.last_feed_types()) - - self.assertEqual(["c:0"], stepper.handle_names()) - self.assertEqual({"c"}, stepper.handle_node_names()) - - # After the cont() call, the stepper should have access to the value of c:0 - # via a tensor handle. 
- self.assertAllClose(6.0, stepper.get_tensor_value("c:0")) - - result = stepper.cont(self.e) - self.assertAllClose(24.0, result) - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) + def testUsingNodesNotUsingIntermediateTensors(self): + with NodeStepper(self.sess, self.e) as stepper: + # There should be no handles before any cont() calls. + self.assertEqual([], stepper.handle_names()) + self.assertSetEqual(set(), stepper.handle_node_names()) + + # Before the cont() call, the stepper should not have access to the value + # of c:0. + with self.assertRaisesRegexp( + ValueError, + "This stepper instance does not have access to the value of tensor " + "\"c:0\""): + stepper.get_tensor_value("c:0") + + # Using the node/tensor itself, instead of the name str, should work on + # cont(). + result = stepper.cont(self.c) + self.assertItemsEqual(["a/read:0", "b/read:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(6.0, result) + self.assertEqual({}, stepper.last_feed_types()) + + self.assertEqual(["c:0"], stepper.handle_names()) + self.assertEqual({"c"}, stepper.handle_node_names()) + + # After the cont() call, the stepper should have access to the value of + # c:0 via a tensor handle. + self.assertAllClose(6.0, stepper.get_tensor_value("c:0")) + + result = stepper.cont(self.e) + self.assertAllClose(24.0, result) + self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"], + stepper.intermediate_tensor_names()) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_HANDLE, + "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, stepper.last_feed_types()) + + def testContToTensorWithIntermediateDumpShouldUseDump(self): + with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: + stepper.cont("c:0") + self.assertItemsEqual(["a/read:0", "b/read:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) + self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) + + self.assertAllClose(2.0, stepper.cont("a/read:0")) + self.assertEqual({ + "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE + }, stepper.last_feed_types()) + + self.assertAllClose(10.0, stepper.cont("f:0")) + self.assertEqual({ + "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE + }, stepper.last_feed_types()) + + def testDisablingUseDumpedIntermediatesWorks(self): + with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: + stepper.cont("c:0") + self.assertItemsEqual(["a/read:0", "b/read:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) + self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) + + self.assertAllClose(10.0, + stepper.cont("f:0", use_dumped_intermediates=False)) + self.assertEqual({}, stepper.last_feed_types()) def testIsFeedableShouldGiveCorrectAnswers(self): - stepper = NodeStepper(self.sess, self.e) - - self.assertTrue(stepper.is_feedable("a/read:0")) - self.assertTrue(stepper.is_feedable("b/read:0")) - self.assertTrue(stepper.is_feedable("c:0")) - self.assertTrue(stepper.is_feedable("d:0")) + with NodeStepper(self.sess, self.e) as stepper: + self.assertTrue(stepper.is_feedable("a/read:0")) + self.assertTrue(stepper.is_feedable("b/read:0")) + self.assertTrue(stepper.is_feedable("c:0")) + self.assertTrue(stepper.is_feedable("d:0")) def testOverrideValue(self): - stepper = NodeStepper(self.sess, self.e) - - result = stepper.cont(self.c) - self.assertAllClose(6.0, result) - self.assertEqual({}, stepper.last_feed_types()) - - # There should be no 
overrides before any cont() calls. - self.assertEqual([], stepper.override_names()) - - # Calling cont() on c again should lead to use of the handle. - result = stepper.cont(self.c) - self.assertAllClose(6.0, result) - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) - - # Override c:0. - stepper.override_tensor("c:0", 7.0) - - # After the overriding, calling get_tensor_value() on c:0 should yield the - # overriding value. - self.assertEqual(7.0, stepper.get_tensor_value("c:0")) - - # Now c:0 should have only an override value, but no cached handle, because - # the handle should have been invalidated. - self.assertEqual([], stepper.handle_names()) - self.assertSetEqual(set(), stepper.handle_node_names()) - self.assertEqual(["c:0"], stepper.override_names()) - - # Run a downstream tensor after the value override. - result = stepper.cont(self.e) - self.assertAllClose(28.0, result) # Should reflect the overriding value. - - # Should use override, instead of the handle. - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) + with NodeStepper(self.sess, self.e) as stepper: + result = stepper.cont(self.c) + self.assertAllClose(6.0, result) + self.assertEqual({}, stepper.last_feed_types()) + + # There should be no overrides before any cont() calls. + self.assertEqual([], stepper.override_names()) + + # Calling cont() on c again should lead to use of the handle. + result = stepper.cont(self.c) + self.assertAllClose(6.0, result) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_HANDLE + }, stepper.last_feed_types()) + + # Override c:0. + stepper.override_tensor("c:0", 7.0) + + # After the overriding, calling get_tensor_value() on c:0 should yield the + # overriding value. + self.assertEqual(7.0, stepper.get_tensor_value("c:0")) + + # Now c:0 should have only an override value, but no cached handle, + # because the handle should have been invalidated. + self.assertEqual([], stepper.handle_names()) + self.assertSetEqual(set(), stepper.handle_node_names()) + self.assertEqual(["c:0"], stepper.override_names()) + + # Run a downstream tensor after the value override. + result = stepper.cont(self.e) + self.assertAllClose(28.0, result) # Should reflect the overriding value. + + # Should use override, instead of the handle. + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_OVERRIDE, + "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, stepper.last_feed_types()) def testOverrideValueTwice(self): - stepper = NodeStepper(self.sess, self.e) - - # Override once. - stepper.override_tensor("c:0", 7.0) - self.assertAllClose(28.0, stepper.cont(self.e)) - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) - - self.assertEqual(["e:0"], stepper.handle_names()) - self.assertSetEqual({"e"}, stepper.handle_node_names()) - self.assertEqual(["c:0"], stepper.override_names()) - - # Calling cont(self.e) again. This time the cached tensor handle of e - # should be used. - self.assertEqual(28.0, stepper.cont(self.e)) - self.assertEqual({ - "e:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) - - # Override c again. This should have invalidated the cache for e. 
- stepper.override_tensor("c:0", 8.0) - - self.assertEqual([], stepper.handle_names()) - self.assertEqual(set(), stepper.handle_node_names()) - self.assertEqual(["c:0"], stepper.override_names()) - - self.assertAllClose(32.0, stepper.cont(self.e)) - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) + with NodeStepper(self.sess, self.e) as stepper: + # Override once. + stepper.override_tensor("c:0", 7.0) + self.assertAllClose(28.0, stepper.cont(self.e)) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_OVERRIDE + }, stepper.last_feed_types()) + + self.assertEqual(["e:0"], stepper.handle_names()) + self.assertSetEqual({"e"}, stepper.handle_node_names()) + self.assertEqual(["c:0"], stepper.override_names()) + + # Calling cont(self.e) again. This time the cached tensor handle of e + # should be used. + self.assertEqual(28.0, stepper.cont(self.e)) + self.assertEqual({ + "e:0": NodeStepper.FEED_TYPE_HANDLE + }, stepper.last_feed_types()) + + # Override c again. This should have invalidated the cache for e. + stepper.override_tensor("c:0", 8.0) + + self.assertEqual([], stepper.handle_names()) + self.assertEqual(set(), stepper.handle_node_names()) + self.assertEqual(["c:0"], stepper.override_names()) + + self.assertAllClose(32.0, stepper.cont(self.e)) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_OVERRIDE, + "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, stepper.last_feed_types()) def testRemoveOverrideValue(self): - stepper = NodeStepper(self.sess, self.e) - - result = stepper.cont(self.c) - self.assertAllClose(6.0, result) - self.assertEqual({}, stepper.last_feed_types()) - - # The previous cont() step should have generated a cached tensor handle. - self.assertEqual(["c:0"], stepper.handle_names()) - self.assertSetEqual({"c"}, stepper.handle_node_names()) - - # Override c:0. - stepper.override_tensor("c:0", 7.0) - - # The overriding should have invalidated the tensor handle. - self.assertEqual([], stepper.handle_names()) - self.assertSetEqual(set(), stepper.handle_node_names()) - self.assertEqual(["c:0"], stepper.override_names()) - - result = stepper.cont(self.e) - self.assertAllClose(28.0, result) # Should reflect the overriding value. - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) - - # The handle to tensor e:0 should have been cached, even though its - # transitive closure contains an override. - self.assertIn("e:0", stepper.handle_names()) - self.assertSetEqual({"e"}, stepper.handle_node_names()) - - # Remove the override. - stepper.remove_override("c:0") - # c:0 should not be in the overrides anymore. - self.assertEqual([], stepper.override_names()) + with NodeStepper(self.sess, self.e) as stepper: + result = stepper.cont(self.c) + self.assertAllClose(6.0, result) + self.assertEqual({}, stepper.last_feed_types()) + + # The previous cont() step should have generated a cached tensor handle. + self.assertEqual(["c:0"], stepper.handle_names()) + self.assertSetEqual({"c"}, stepper.handle_node_names()) + + # Override c:0. + stepper.override_tensor("c:0", 7.0) + + # The overriding should have invalidated the tensor handle. + self.assertEqual([], stepper.handle_names()) + self.assertSetEqual(set(), stepper.handle_node_names()) + self.assertEqual(["c:0"], stepper.override_names()) + + result = stepper.cont(self.e) + self.assertAllClose(28.0, result) # Should reflect the overriding value. 
+ self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_OVERRIDE, + "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + }, stepper.last_feed_types()) + + # The handle to tensor e:0 should have been cached, even though its + # transitive closure contains an override. + self.assertIn("e:0", stepper.handle_names()) + self.assertSetEqual({"e"}, stepper.handle_node_names()) + + # Remove the override. + stepper.remove_override("c:0") + # c:0 should not be in the overrides anymore. + self.assertEqual([], stepper.override_names()) - # Removing the override should have invalidated the tensor handle for c. - self.assertNotIn("e:0", stepper.handle_names()) - self.assertNotIn("e", stepper.handle_node_names()) + # Removing the override should have invalidated the tensor handle for c. + self.assertNotIn("e:0", stepper.handle_names()) + self.assertNotIn("e", stepper.handle_node_names()) - # Should reflect the non-overriding value. - self.assertAllClose(24.0, stepper.cont(self.e)) + # Should reflect the non-overriding value. + self.assertAllClose(24.0, stepper.cont(self.e)) - # This time, the handle to tensor e:0 should have been cached again, even - # thought its transitive closure contains an override. - self.assertIn("e:0", stepper.handle_names()) - self.assertIn("e", stepper.handle_node_names()) + # This time, the handle to tensor e:0 should have been cached again, even + # thought its transitive closure contains an override. + self.assertIn("e:0", stepper.handle_names()) + self.assertIn("e", stepper.handle_node_names()) - # Calling cont(self.e) again should have used the tensor handle to e:0. - self.assertAllClose(24.0, stepper.cont(self.e)) - self.assertEqual({ - "e:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) + # Calling cont(self.e) again should have used the tensor handle to e:0. + self.assertAllClose(24.0, stepper.cont(self.e)) + self.assertEqual({ + "e:0": NodeStepper.FEED_TYPE_HANDLE, + }, stepper.last_feed_types()) def testOverrideAndContToSameTensor(self): - stepper = NodeStepper(self.sess, self.e) + with NodeStepper(self.sess, self.e) as stepper: + result = stepper.cont(self.c) + self.assertAllClose(6.0, result) + self.assertEqual({}, stepper.last_feed_types()) + self.assertEqual(["c:0"], stepper.handle_names()) + self.assertSetEqual({"c"}, stepper.handle_node_names()) - result = stepper.cont(self.c) - self.assertAllClose(6.0, result) - self.assertEqual({}, stepper.last_feed_types()) - self.assertEqual(["c:0"], stepper.handle_names()) - self.assertSetEqual({"c"}, stepper.handle_node_names()) + self.assertAllClose(6.0, stepper.cont(self.c)) - self.assertAllClose(6.0, stepper.cont(self.c)) + # The last cont() call should use the tensor handle directly. + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_HANDLE + }, stepper.last_feed_types()) - # The last cont() call should use the tensor handle directly. - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) + # Override c:0. + stepper.override_tensor("c:0", 7.0) - # Override c:0. - stepper.override_tensor("c:0", 7.0) + # As a result of the override, the tensor handle should have been + # invalidated. + self.assertEqual([], stepper.handle_names()) + self.assertSetEqual(set(), stepper.handle_node_names()) - # As a result of the override, the tensor handle should have been - # invalidated. 
- self.assertEqual([], stepper.handle_names()) - self.assertSetEqual(set(), stepper.handle_node_names()) + result = stepper.cont(self.c) + self.assertAllClose(7.0, result) - result = stepper.cont(self.c) - self.assertAllClose(7.0, result) - - self.assertEqual({ - "c:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) + self.assertEqual({ + "c:0": NodeStepper.FEED_TYPE_OVERRIDE + }, stepper.last_feed_types()) def testFinalizeWithPreviousOverrides(self): - stepper = NodeStepper(self.sess, self.e) + with NodeStepper(self.sess, self.e) as stepper: + stepper.override_tensor("a/read:0", 20.0) + self.assertEqual(["a/read:0"], stepper.override_names()) - stepper.override_tensor("a/read:0", 20.0) - self.assertEqual(["a/read:0"], stepper.override_names()) + # Should reflect the overriding value. + self.assertAllClose(24000.0, stepper.cont("e:0")) + self.assertEqual({ + "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE + }, stepper.last_feed_types()) - # Should reflect the overriding value. - self.assertAllClose(24000.0, stepper.cont("e:0")) - self.assertEqual({ - "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) - - # Finalize call should have ignored the overriding value. - self.assertAllClose(24.0, stepper.finalize()) + # Finalize call should have ignored the overriding value. + self.assertAllClose(24.0, stepper.finalize()) def testRemoveNonexistentOverrideValue(self): - stepper = NodeStepper(self.sess, self.e) - self.assertEqual([], stepper.override_names()) - - with self.assertRaisesRegexp( - ValueError, "No overriding value exists for tensor \"c:0\""): - stepper.remove_override("c:0") + with NodeStepper(self.sess, self.e) as stepper: + self.assertEqual([], stepper.override_names()) + with self.assertRaisesRegexp( + ValueError, "No overriding value exists for tensor \"c:0\""): + stepper.remove_override("c:0") def testAttemptToOverrideInvalidTensor(self): stepper = NodeStepper(self.sess, self.e) @@ -333,20 +367,18 @@ class StepperTest(test_util.TensorFlowTestCase): stepper.override_tensor("f:0", 42.0) def testInvalidOverrideArgumentType(self): - stepper = NodeStepper(self.sess, self.e) - - with self.assertRaisesRegexp(TypeError, "Expected type str; got type"): - stepper.override_tensor(self.a, 42.0) + with NodeStepper(self.sess, self.e) as stepper: + with self.assertRaisesRegexp(TypeError, "Expected type str; got type"): + stepper.override_tensor(self.a, 42.0) def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self): - stepper = NodeStepper(self.sess, "z:0") - - sorted_nodes = stepper.sorted_nodes() - self.assertEqual(4, len(sorted_nodes)) - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) - self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) + with NodeStepper(self.sess, "z:0") as stepper: + sorted_nodes = stepper.sorted_nodes() + self.assertEqual(4, len(sorted_nodes)) + self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) + self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) + self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) + self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self): for i in range(6): @@ -363,45 +395,44 @@ class StepperTest(test_util.TensorFlowTestCase): elif i == 5: fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}} - stepper = 
NodeStepper(self.sess, fetches) - - sorted_nodes = stepper.sorted_nodes() - self.assertEqual(13, len(sorted_nodes)) - - # Check the topological order of the sorted nodes. - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) - self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) - self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) - - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) - self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) - self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) - self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) - self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) - self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) - self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f")) - self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f")) - - closure_elements = stepper.closure_elements() - self.assertIn("x/read:0", closure_elements) - self.assertIn("e:0", closure_elements) - self.assertIn("f:0", closure_elements) - - self.assertEqual([0], stepper.output_slots_in_closure("x/read")) - self.assertEqual([0], stepper.output_slots_in_closure("e")) - self.assertEqual([0], stepper.output_slots_in_closure("f")) - - result = stepper.finalize() - if i == 0 or i == 1 or i == 3 or i == 4: - self.assertAllClose(24.0, result[0]) - self.assertAllClose(10.0, result[1][0]) - self.assertAllClose(-4.0, result[1][1]) - elif i == 2 or i == 5: - self.assertAllClose(24.0, result["e"]) - self.assertAllClose(10.0, result["fz"]["f"]) - self.assertAllClose(-4.0, result["fz"]["z"]) + with NodeStepper(self.sess, fetches) as stepper: + sorted_nodes = stepper.sorted_nodes() + self.assertEqual(13, len(sorted_nodes)) + + # Check the topological order of the sorted nodes. 
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) + self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) + self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) + self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) + + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) + self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) + self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) + self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) + self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) + self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) + self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f")) + self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f")) + + closure_elements = stepper.closure_elements() + self.assertIn("x/read:0", closure_elements) + self.assertIn("e:0", closure_elements) + self.assertIn("f:0", closure_elements) + + self.assertEqual([0], stepper.output_slots_in_closure("x/read")) + self.assertEqual([0], stepper.output_slots_in_closure("e")) + self.assertEqual([0], stepper.output_slots_in_closure("f")) + + result = stepper.finalize() + if i == 0 or i == 1 or i == 3 or i == 4: + self.assertAllClose(24.0, result[0]) + self.assertAllClose(10.0, result[1][0]) + self.assertAllClose(-4.0, result[1][1]) + elif i == 2 or i == 5: + self.assertAllClose(24.0, result["e"]) + self.assertAllClose(10.0, result["fz"]["f"]) + self.assertAllClose(-4.0, result["fz"]["z"]) class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase): @@ -419,128 +450,222 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase): ops.reset_default_graph() def testGetTensorValueWorksOnPlaceholder(self): - stepper = NodeStepper( + with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: [[1.0, 2.0], [-3.0, 5.0]], self.ph1: [[-1.0], [0.5]] - }) - - self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], - stepper.get_tensor_value("ph0")) - self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], - stepper.get_tensor_value("ph0:0")) - with self.assertRaisesRegexp( - KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"): - stepper.get_tensor_value("ph0:1") + }) as stepper: + self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], + stepper.get_tensor_value("ph0")) + self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], + stepper.get_tensor_value("ph0:0")) + with self.assertRaisesRegexp( + KeyError, + r"The name 'ph0:1' refers to a Tensor which does not exist"): + stepper.get_tensor_value("ph0:1") def testIsPlaceholdersShouldGiveCorrectAnswers(self): - stepper = NodeStepper(self.sess, self.y) - - self.assertTrue(stepper.is_placeholder(self.ph0.name)) - self.assertTrue(stepper.is_placeholder(self.ph1.name)) + with NodeStepper(self.sess, self.y) as stepper: + self.assertTrue(stepper.is_placeholder(self.ph0.name)) + self.assertTrue(stepper.is_placeholder(self.ph1.name)) - self.assertFalse(stepper.is_placeholder(self.x.name)) - self.assertFalse(stepper.is_placeholder(self.y.name)) + self.assertFalse(stepper.is_placeholder(self.x.name)) + self.assertFalse(stepper.is_placeholder(self.y.name)) - with self.assertRaisesRegexp(ValueError, - "A is not in the transitive closure"): - self.assertFalse(stepper.is_placeholder("A")) + with self.assertRaisesRegexp(ValueError, + "A is not in the transitive closure"): + self.assertFalse(stepper.is_placeholder("A")) def 
testPlaceholdersShouldGiveCorrectAnswers(self): - stepper = NodeStepper(self.sess, self.y) - - self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders())) + with NodeStepper(self.sess, self.y) as stepper: + self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders())) def testContWithPlaceholders(self): - stepper = NodeStepper( + with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: [[1.0, 2.0], [-3.0, 5.0]], self.ph1: [[-1.0], [0.5]] - }) - - self.assertEqual(4, len(stepper.sorted_nodes())) - self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"}, - set(stepper.closure_elements())) - - result = stepper.cont(self.x) - self.assertAllClose([[0.0], [5.5]], result) - self.assertEqual({ - "ph0:0": NodeStepper.FEED_TYPE_CLIENT, - "ph1:0": NodeStepper.FEED_TYPE_CLIENT, - }, stepper.last_feed_types()) - - self.assertEqual(["x:0"], stepper.handle_names()) - self.assertSetEqual({"x"}, stepper.handle_node_names()) - - result = stepper.cont(self.y) - self.assertAllClose([[-1.0], [6.0]], result) - self.assertEqual({ - "x:0": NodeStepper.FEED_TYPE_HANDLE, - "ph1:0": NodeStepper.FEED_TYPE_CLIENT, - }, stepper.last_feed_types()) + }) as stepper: + self.assertEqual(4, len(stepper.sorted_nodes())) + self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"}, + set(stepper.closure_elements())) + + result = stepper.cont(self.x) + self.assertAllClose([[0.0], [5.5]], result) + self.assertEqual({ + "ph0:0": NodeStepper.FEED_TYPE_CLIENT, + "ph1:0": NodeStepper.FEED_TYPE_CLIENT, + }, stepper.last_feed_types()) + + self.assertEqual(["x:0"], stepper.handle_names()) + self.assertSetEqual({"x"}, stepper.handle_node_names()) + + result = stepper.cont(self.y) + self.assertAllClose([[-1.0], [6.0]], result) + self.assertEqual({ + "x:0": NodeStepper.FEED_TYPE_HANDLE, + "ph1:0": NodeStepper.FEED_TYPE_CLIENT, + }, stepper.last_feed_types()) def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self): """Continuing to a placeholder should be allowed, using client feed.""" ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] ph1_feed = [[-1.0], [0.5]] - stepper = NodeStepper( + with NodeStepper( self.sess, self.y, feed_dict={ self.ph0: ph0_feed, self.ph1: ph1_feed, - }) - - self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) - self.assertEqual({ - self.ph0.name: NodeStepper.FEED_TYPE_CLIENT - }, stepper.last_feed_types()) + }) as stepper: + self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) + self.assertEqual({ + self.ph0.name: NodeStepper.FEED_TYPE_CLIENT + }, stepper.last_feed_types()) - self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) - self.assertEqual({ - self.ph1.name: NodeStepper.FEED_TYPE_CLIENT - }, stepper.last_feed_types()) + self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) + self.assertEqual({ + self.ph1.name: NodeStepper.FEED_TYPE_CLIENT + }, stepper.last_feed_types()) - ph0_node = self.sess.graph.as_graph_element("ph0") - self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) - self.assertEqual({ - self.ph0.name: NodeStepper.FEED_TYPE_CLIENT - }, stepper.last_feed_types()) + ph0_node = self.sess.graph.as_graph_element("ph0") + self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) + self.assertEqual({ + self.ph0.name: NodeStepper.FEED_TYPE_CLIENT + }, stepper.last_feed_types()) - self.assertAllClose([[-1.0], [6.0]], stepper.finalize()) + self.assertAllClose([[-1.0], [6.0]], stepper.finalize()) def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self): ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] ph1_feed = [[-1.0], [0.5]] - stepper = NodeStepper( + with NodeStepper( self.sess, self.y, 
feed_dict={
             self.ph0.name: ph0_feed,
             self.ph1.name: ph1_feed,
-        })
+        }) as stepper:
+      self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
+      self.assertEqual({
+          self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+      }, stepper.last_feed_types())
 
-    self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
-    self.assertEqual({
-        self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
-    }, stepper.last_feed_types())
+      self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
+      self.assertEqual({
+          self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
+      }, stepper.last_feed_types())
 
-    self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
-    self.assertEqual({
-        self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
-    }, stepper.last_feed_types())
+      ph0_node = self.sess.graph.as_graph_element("ph0")
+      self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
+      self.assertEqual({
+          self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+      }, stepper.last_feed_types())
 
-    ph0_node = self.sess.graph.as_graph_element("ph0")
-    self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
-    self.assertEqual({
-        self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
-    }, stepper.last_feed_types())
+      self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
 
-    self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
+
+class StepperAssignAddTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.v = variables.Variable(10.0, name="v")
+    self.p = math_ops.add(self.v, self.v, name="p")
+    self.q = math_ops.multiply(self.p, self.p, name="q")
+    self.delta = constant_op.constant(2.0, name="delta")
+    self.v_add = state_ops.assign_add(self.v, self.delta, name="v_add")
+    self.v_add_plus_one = math_ops.add(self.v_add,
+                                       1.0,
+                                       name="v_add_plus_one")
+
+    self.sess = session.Session()
+    self.sess.run(self.v.initializer)
+
+  def tearDown(self):
+    ops.reset_default_graph()
+
+  def testContToUpdateInvalidatesDumpedIntermediates(self):
+    with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
+      self.assertAllClose(400.0, stepper.cont("q:0"))
+      self.assertItemsEqual(["v/read:0", "p:0"],
+                            stepper.intermediate_tensor_names())
+      self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
+      self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
+
+      self.assertAllClose(
+          12.0, stepper.cont(
+              self.v_add, invalidate_from_updated_variables=True))
+      self.assertAllClose(12.0, self.sess.run(self.v))
+      self.assertItemsEqual(["v:0"], stepper.dirty_variables())
+      # Updating the value of v by calling v_add should have invalidated the
+      # dumped intermediate tensors for v/read:0 and p:0.
+      self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names())
+      with self.assertRaisesRegexp(
+          ValueError,
+          r"This stepper instance does not have access to the value of tensor "
+          r"\"p:0\""):
+        stepper.get_tensor_value("p:0")
+
+      # The next cont to q should not have used any dumped intermediate tensors
+      # and its result should reflect the updated value.
+ self.assertAllClose(576.0, stepper.cont("q:0")) + self.assertEqual({}, stepper.last_feed_types()) + + def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self): + with NodeStepper(self.sess, self.q) as stepper: + self.assertAllClose(400.0, stepper.cont("q:0")) + self.assertItemsEqual(["v/read:0", "p:0"], + stepper.intermediate_tensor_names()) + self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0")) + self.assertAllClose(20.0, stepper.get_tensor_value("p:0")) + + stepper.override_tensor("v/read:0", 11.0) + self.assertItemsEqual(["v/read:0"], stepper.override_names()) + # Overriding the upstream v/read:0 should have invalidated the dumped + # intermediate tensor for the downstream p:0. + self.assertItemsEqual([], stepper.intermediate_tensor_names()) + + # The next cont to q should not have used any dumped intermediate tensors + # and its result should reflect the overriding value. + self.assertAllClose(484.0, stepper.cont("q:0")) + self.assertEqual({ + "v/read:0": NodeStepper.FEED_TYPE_OVERRIDE + }, stepper.last_feed_types()) + + def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self): + with NodeStepper(self.sess, self.q) as stepper: + stepper.override_tensor("v/read:0", 9.0) + self.assertItemsEqual(["v/read:0"], stepper.override_names()) + + self.assertAllClose(324.0, stepper.cont(self.q)) + self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names()) + + stepper.remove_override("v/read:0") + self.assertItemsEqual([], stepper.override_names()) + # Removing the pre-existing override to v/read:0 should have invalidated + # the dumped intermediate tensor. + self.assertItemsEqual([], stepper.intermediate_tensor_names()) + + def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self): + with NodeStepper(self.sess, self.v_add) as stepper: + stepper.cont(self.v_add) + self.assertAllClose(12.0, stepper.cont(self.v)) + stepper.cont(self.v_add) + self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE}, + stepper.last_feed_types()) + self.assertAllClose(12.0, stepper.cont(self.v)) + + def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self): + with NodeStepper(self.sess, self.v_add_plus_one) as stepper: + stepper.cont(self.v_add_plus_one) + self.assertAllClose(12.0, stepper.cont(self.v)) + stepper.cont(self.v_add_plus_one) + self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE}, + stepper.last_feed_types()) + self.assertAllClose(12.0, stepper.cont(self.v)) class StepperBackwardRunTest(test_util.TensorFlowTestCase): @@ -580,93 +705,136 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): ops.reset_default_graph() def testContToUpdateA(self): - stepper = NodeStepper(self.sess, "optim") - - result = stepper.cont("a:0") - self.assertAllClose(1.0, result) - self.assertEqual({}, stepper.last_feed_types()) - - result = stepper.cont("optim/learning_rate:0") - self.assertAllClose(0.01, result) - self.assertEqual({}, stepper.last_feed_types()) - - # Before any cont calls on ApplyGradientDescent, there should be no "dirty" - # variables. - self.assertEqual(set(), stepper.dirty_variables()) - - # First, all the two control inputs to optim. - result = stepper.cont("optim/update_a/ApplyGradientDescent") - - # Now variable a should have been marked as dirty due to the update - # by optim/update_a/ApplyGradientDescent. 
-    self.assertEqual({"a:0"}, stepper.dirty_variables())
-    self.assertIsNone(result)
-    self.assertEqual({
-        "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
-    }, stepper.last_feed_types())
-
-    # Check that Variable "a" has been updated properly, but "b", "c" and "d"
-    # remain the same.
-    # For backprop on Variable a:
-    # Because f = a * b * b * c, df / da = b * b * c.
-    # 1.0 - learning_rate * b * b * c
-    # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
-    self.assertAllClose(0.84, self.sess.run(self.a))
-    self.assertAllClose(2.0, self.sess.run(self.b))
-    self.assertAllClose(4.0, self.sess.run(self.c))
+    with NodeStepper(self.sess, "optim") as stepper:
+      result = stepper.cont("a:0")
+      self.assertAllClose(1.0, result)
+      self.assertEqual({}, stepper.last_feed_types())
+
+      result = stepper.cont("optim/learning_rate:0")
+      self.assertAllClose(0.01, result)
+      self.assertEqual({}, stepper.last_feed_types())
+
+      # Before any cont calls on ApplyGradientDescent, there should be no
+      # "dirty" variables.
+      self.assertEqual(set(), stepper.dirty_variables())
+
+      # First, run one of the two control inputs to optim.
+      result = stepper.cont("optim/update_a/ApplyGradientDescent",
+                            invalidate_from_updated_variables=True)
+
+      # Now variable a should have been marked as dirty due to the update
+      # by optim/update_a/ApplyGradientDescent.
+      self.assertEqual({"a:0"}, stepper.dirty_variables())
+      self.assertIsNone(result)
+      self.assertEqual({
+          "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
+      }, stepper.last_feed_types())
+
+      # Check that Variable "a" has been updated properly, but "b", "c" and "d"
+      # remain the same.
+      # For backprop on Variable a:
+      # Because f = a * b * b * c, df / da = b * b * c.
+      # 1.0 - learning_rate * b * b * c
+      # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
+      self.assertAllClose(0.84, self.sess.run(self.a))
+      self.assertAllClose(2.0, self.sess.run(self.b))
+      self.assertAllClose(4.0, self.sess.run(self.c))
 
   def testContToUpdateB(self):
-    stepper = NodeStepper(self.sess, "optim")
-
-    result = stepper.cont("optim/update_b/ApplyGradientDescent")
-    self.assertIsNone(result)
-    self.assertEqual(set(["b:0"]), stepper.dirty_variables())
-
-    # For backprop on Variable b:
-    # Because f = a * b * b * c, df / da = 2 * a * b * c.
-    # 2.0 - learning_rate * 2 * a * b * c
-    # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
-    self.assertAllClose(1.0, self.sess.run(self.a))
-    self.assertAllClose(1.84, self.sess.run(self.b))
-    self.assertAllClose(4.0, self.sess.run(self.c))
+    with NodeStepper(self.sess, "optim") as stepper:
+      result = stepper.cont("optim/update_b/ApplyGradientDescent",
+                            invalidate_from_updated_variables=True)
+      self.assertIsNone(result)
+      self.assertEqual(set(["b:0"]), stepper.dirty_variables())
+
+      # For backprop on Variable b:
+      # Because f = a * b * b * c, df / db = 2 * a * b * c.
+      # 2.0 - learning_rate * 2 * a * b * c
+      # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
+      self.assertAllClose(1.0, self.sess.run(self.a))
+      self.assertAllClose(1.84, self.sess.run(self.b))
+      self.assertAllClose(4.0, self.sess.run(self.c))
 
   def testContAfterUpdateWithoutRestoringVariableValue(self):
-    stepper = NodeStepper(self.sess, "optim")
-
-    # First, update Variable a from 1.0 to 0.84.
- result = stepper.cont( - "optim/update_a/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - self.assertEqual(set(["a:0"]), stepper.dirty_variables()) - self.assertAllClose(0.84, self.sess.run(self.a)) - self.assertAllClose(2.0, self.sess.run(self.b)) - self.assertAllClose(4.0, self.sess.run(self.c)) - - # Second, update Variable b without the default restore_variable_values. - result = stepper.cont( - "optim/update_b/ApplyGradientDescent", restore_variable_values=False) - self.assertIsNone(result) - # For the backprop on Variable b under the updated value of a: - # 2.0 - learning_rate * 2 * a' * b * c - # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656 - self.assertAllClose(0.84, self.sess.run(self.a)) - self.assertAllClose(1.8656, self.sess.run(self.b)) - self.assertAllClose(4.0, self.sess.run(self.c)) + with NodeStepper(self.sess, "optim") as stepper: + # First, update Variable a from 1.0 to 0.84. + result = stepper.cont( + "optim/update_a/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) + self.assertIsNone(result) + self.assertEqual(set(["a:0"]), stepper.dirty_variables()) + self.assertAllClose(0.84, self.sess.run(self.a)) + self.assertAllClose(2.0, self.sess.run(self.b)) + self.assertAllClose(4.0, self.sess.run(self.c)) + # Tracking of the updated variables should have invalidated all + # intermediate tensors downstream to a:0. + self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) + self.assertNotIn("d:0", stepper.intermediate_tensor_names()) + + # Second, update Variable b without the default restore_variable_values. + result = stepper.cont( + "optim/update_b/ApplyGradientDescent", restore_variable_values=False) + self.assertIsNone(result) + # For the backprop on Variable b under the updated value of a: + # 2.0 - learning_rate * 2 * a' * b * c + # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656 + self.assertAllClose(0.84, self.sess.run(self.a)) + self.assertAllClose(1.8656, self.sess.run(self.b)) + self.assertAllClose(4.0, self.sess.run(self.c)) + + def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self): + with NodeStepper(self.sess, "optim") as stepper: + self.assertIsNone(stepper.cont( + "optim/update_a/ApplyGradientDescent", + invalidate_from_updated_variables=False)) + # Even though invalidate_from_updated_variables is set to False, dirty + # variables should still have been tracked. + self.assertEqual({"a:0"}, stepper.dirty_variables()) + self.assertIn("a/read:0", stepper.intermediate_tensor_names()) + self.assertIn("b/read:0", stepper.intermediate_tensor_names()) + self.assertIn("c/read:0", stepper.intermediate_tensor_names()) + self.assertIn("d:0", stepper.intermediate_tensor_names()) + self.assertIn("e:0", stepper.intermediate_tensor_names()) + self.assertIn("optim/learning_rate:0", + stepper.intermediate_tensor_names()) + self.assertNotIn("a:0", stepper.intermediate_tensor_names()) + self.assertNotIn("b:0", stepper.intermediate_tensor_names()) + self.assertNotIn("c:0", stepper.intermediate_tensor_names()) + + self.assertAllClose(0.84, self.sess.run(self.a)) + self.assertAllClose(2.0, self.sess.run(self.b)) + self.assertAllClose(4.0, self.sess.run(self.c)) + + # For the backprop on Variable b, the result should reflect the original + # value of Variable a, even though Variable a has actually been updated. 
+ # 2.0 - learning_rate * 2 * a * b * c + # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84 + self.assertIsNone(stepper.cont( + "optim/update_b/ApplyGradientDescent", + invalidate_from_updated_variables=False, + restore_variable_values=False)) + self.assertAllClose(0.84, self.sess.run(self.a)) + self.assertAllClose(1.84, self.sess.run(self.b)) + self.assertAllClose(4.0, self.sess.run(self.c)) def testUpdateTwiceRestoreVariable(self): - stepper = NodeStepper(self.sess, "optim") - - result = stepper.cont( - "optim/update_a/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - self.assertEqual({"a:0"}, stepper.dirty_variables()) - - result = stepper.cont( - "optim/update_b/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - # Variables a and c should have been restored and hence no longer dirty. - # Variable b should have been marked as dirty. - self.assertEqual({"b:0"}, stepper.dirty_variables()) + with NodeStepper(self.sess, "optim") as stepper: + result = stepper.cont( + "optim/update_a/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) + self.assertIsNone(result) + self.assertEqual({"a:0"}, stepper.dirty_variables()) + + result = stepper.cont( + "optim/update_b/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) + self.assertIsNone(result) + # Variables a and c should have been restored and hence no longer dirty. + # Variable b should have been marked as dirty. + self.assertEqual({"b:0"}, stepper.dirty_variables()) # The result of the update should be identitcal to as if only update_b is # run. @@ -680,179 +848,203 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase): "clean" means no Variables have been updated by preceding cont() calls. """ - stepper = NodeStepper(self.sess, "optim") - - # First, call cont() on the two tensors on the intermediate level: e and f. - result = stepper.cont("d:0") - self.assertAllClose(2.0, result) - self.assertEqual({}, stepper.last_feed_types()) - self.assertEqual(set(), stepper.dirty_variables()) - - # The cont call above should have restored Variable "b". - result = stepper.cont("e:0") - self.assertAllClose(8.0, result) - self.assertEqual({}, stepper.last_feed_types()) - self.assertEqual(set(), stepper.dirty_variables()) - - # Now run update_a, so as to let Variable a be diry. - result = stepper.cont( - "optim/update_a/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - self.assertEqual({"a:0"}, stepper.dirty_variables()) - - # Now, run update_b. - result = stepper.cont( - "optim/update_b/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - - # The last cont() run should have use the handle of tensor e, but not the - # handle of tensor d, because the transitive closure of e is clean, whereas - # that of d is dirty due to the update to a in the previous cont() call. - self.assertEqual({ - "e:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) - - # The result of the update_b should be identical to as if no other - # update_* cont() calls have occurred before. - self.assertAllClose(1.0, self.sess.run(self.a)) - self.assertAllClose(1.84, self.sess.run(self.b)) - self.assertAllClose(4.0, self.sess.run(self.c)) + with NodeStepper(self.sess, "optim") as stepper: + # First, call cont() on the two tensors on the intermediate level: e and + # f. 
+ result = stepper.cont("d:0") + self.assertAllClose(2.0, result) + self.assertEqual({}, stepper.last_feed_types()) + self.assertItemsEqual(["a/read:0", "b/read:0"], + stepper.intermediate_tensor_names()) + self.assertItemsEqual(["d:0"], stepper.handle_names()) + self.assertEqual(set(), stepper.dirty_variables()) + + result = stepper.cont("e:0") + self.assertAllClose(8.0, result) + self.assertEqual({ + "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE + }, stepper.last_feed_types()) + self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names()) + self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"], + stepper.intermediate_tensor_names()) + self.assertEqual(set(), stepper.dirty_variables()) + + # Now run update_a, so as to let Variable a be dirty. + result = stepper.cont( + "optim/update_a/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) + self.assertIsNone(result) + # Due to the update to the value of a:0, the dumped intermediate a/read:0 + # should have been invalidated. + self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) + # ["b/read:0", "c/read:0"], + self.assertEqual({"a:0"}, stepper.dirty_variables()) + + # Now, run update_b. + result = stepper.cont( + "optim/update_b/ApplyGradientDescent", restore_variable_values=True) + self.assertIsNone(result) + + # The last cont() run should have use the handle of tensor e, but not the + # handle of tensor d, because the transitive closure of e is clean, + # whereas that of d is dirty due to the update to a in the previous cont() + # call. + last_feed_types = stepper.last_feed_types() + self.assertNotIn("d:0", last_feed_types) + self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + last_feed_types["b/read:0"]) + self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, + last_feed_types["c/read:0"]) + + # The result of the update_b should be identical to as if no other + # update_* cont() calls have occurred before. + self.assertAllClose(1.0, self.sess.run(self.a)) + self.assertAllClose(1.84, self.sess.run(self.b)) + self.assertAllClose(4.0, self.sess.run(self.c)) def testRestoreVariableValues(self): """Test restore_variable_values() restores the old values of variables.""" - stepper = NodeStepper(self.sess, "optim") - - stepper.cont( - "optim/update_b/ApplyGradientDescent", restore_variable_values=True) - self.assertAllClose(1.84, self.sess.run(self.b)) + with NodeStepper(self.sess, "optim") as stepper: + stepper.cont( + "optim/update_b/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) + self.assertAllClose(1.84, self.sess.run(self.b)) - stepper.restore_variable_values() - self.assertAllClose(2.0, self.sess.run(self.b)) + stepper.restore_variable_values() + self.assertAllClose(2.0, self.sess.run(self.b)) def testFinalize(self): """Test finalize() to restore variables and run the original fetch.""" - stepper = NodeStepper(self.sess, "optim") - - # Invoke update_b before calling finalize. - stepper.cont( - "optim/update_b/ApplyGradientDescent", restore_variable_values=True) + with NodeStepper(self.sess, "optim") as stepper: + # Invoke update_b before calling finalize. + stepper.cont( + "optim/update_b/ApplyGradientDescent", + invalidate_from_updated_variables=True, + restore_variable_values=True) - result = stepper.finalize() - self.assertIsNone(result) + result = stepper.finalize() + self.assertIsNone(result) - # The results of the Variable updates should be the same as if no cont() - # call has occurred on update_b. 
- self.assertAllClose(0.84, self.sess.run(self.a)) - self.assertAllClose(1.84, self.sess.run(self.b)) - self.assertAllClose(3.96, self.sess.run(self.c)) + # The results of the Variable updates should be the same as if no cont() + # call has occurred on update_b. + self.assertAllClose(0.84, self.sess.run(self.a)) + self.assertAllClose(1.84, self.sess.run(self.b)) + self.assertAllClose(3.96, self.sess.run(self.c)) - def testOverrideThenContToUpdate(self): + def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self): """Test cont() to update nodes after overriding tensor values.""" - stepper = NodeStepper(self.sess, "optim") - - result = stepper.cont("d:0") - self.assertAllClose(2.0, result) - self.assertEqual({}, stepper.last_feed_types()) - self.assertEqual(set(), stepper.dirty_variables()) - self.assertEqual(["d:0"], stepper.handle_names()) - self.assertSetEqual({"d"}, stepper.handle_node_names()) - - # Override the value from 1.0 to 10.0. - stepper.override_tensor("a/read:0", 10.0) - - self.assertEqual(["a/read:0"], stepper.override_names()) - - result = stepper.cont( - "optim/update_c/ApplyGradientDescent", restore_variable_values=True) - self.assertIsNone(result) - - # The last cont() call should have not used the tensor handle to d:0, - # because the transitive closure of d:0 contains an override tensor. - self.assertEqual({ - "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE - }, stepper.last_feed_types()) - - # The tensor handle to d:0 should have been removed due to the dirty - # transitive closure. - self.assertEqual([], stepper.handle_names()) - self.assertSetEqual(set(), stepper.handle_node_names()) - - # For this backprop on c, the overriding value of a/read:0 should have been - # used: - # 4.0 - learning_rate * a * b * b - # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6. - self.assertAllClose(3.6, self.sess.run(self.c)) - - # Now remove the overriding value of a/read:0. - stepper.remove_override("a/read:0") - self.assertEqual([], stepper.override_names()) - - # Obtain the tensor handle to d:0 again. - result = stepper.cont("d:0") - self.assertAllClose(2.0, result) - self.assertEqual(["d:0"], stepper.handle_names()) - self.assertSetEqual({"d"}, stepper.handle_node_names()) - - # Then call update_c again, without restoring c. - result = stepper.cont( - "optim/update_c/ApplyGradientDescent", restore_variable_values=False) - self.assertIsNone(result) - - # This time, the d:0 tensor handle should have been used, because its - # transitive closure is clean. - self.assertEqual({ - "d:0": NodeStepper.FEED_TYPE_HANDLE - }, stepper.last_feed_types()) - - # For this backprop on c, the overriding value of a/read:0 should have been - # used: - # 3.6 - learning_rate * a * b * b - # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56. - self.assertAllClose(3.56, self.sess.run(self.c)) + with NodeStepper(self.sess, "optim") as stepper: + result = stepper.cont("d:0") + self.assertAllClose(2.0, result) + self.assertEqual({}, stepper.last_feed_types()) + self.assertEqual(set(), stepper.dirty_variables()) + self.assertEqual(["d:0"], stepper.handle_names()) + self.assertSetEqual({"d"}, stepper.handle_node_names()) + + # Override the value from 1.0 to 10.0. 
+      stepper.override_tensor("a/read:0", 10.0)
+
+      self.assertEqual(["a/read:0"], stepper.override_names())
+
+      result = stepper.cont(
+          "optim/update_c/ApplyGradientDescent",
+          invalidate_from_updated_variables=True,
+          restore_variable_values=True)
+      self.assertIsNone(result)
+
+      # The last cont() call should not have used the tensor handle to d:0,
+      # because the transitive closure of d:0 contains an override tensor.
+      self.assertEqual({
+          "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE,
+          "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+      }, stepper.last_feed_types())
+
+      # The tensor handle to d:0 should have been removed due to the dirty
+      # transitive closure.
+      self.assertEqual([], stepper.handle_names())
+      self.assertSetEqual(set(), stepper.handle_node_names())
+
+      # For this backprop on c, the overriding value of a/read:0 should have
+      # been used:
+      # 4.0 - learning_rate * a * b * b
+      # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
+      self.assertAllClose(3.6, self.sess.run(self.c))
+
+      # Now remove the overriding value of a/read:0.
+      stepper.remove_override("a/read:0")
+      self.assertEqual([], stepper.override_names())
+
+      # Obtain the tensor handle to d:0 again.
+      result = stepper.cont("d:0")
+      self.assertAllClose(2.0, result)
+      self.assertEqual(["d:0"], stepper.handle_names())
+      self.assertSetEqual({"d"}, stepper.handle_node_names())
+      self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+      # Then call update_c again, without restoring c.
+      result = stepper.cont("optim/update_c/ApplyGradientDescent",
+                            restore_variable_values=False)
+      self.assertIsNone(result)
+      self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+      # This time, the d:0 tensor handle should have been used, because its
+      # transitive closure is clean.
+      self.assertEqual({
+          "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+          "d:0": NodeStepper.FEED_TYPE_HANDLE,
+          "optim/learning_rate:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+      }, stepper.last_feed_types())
+
+      # For this backprop on c, the original, non-overridden value of a/read:0
+      # should have been used:
+      # 3.6 - learning_rate * a * b * b
+      # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
+      self.assertAllClose(3.56, self.sess.run(self.c))
 
   def testContToNodeWithOutputTensors(self):
     """cont() to an op should cache its output tensors if appropriate."""
-    stepper = NodeStepper(self.sess, "optim")
-
-    # In the transitive closure of the stepper, look for an op of which the
-    # output tensor also is in the transitive closure.
-    # Do not assume a specific op, e.g., ""gradients/e_grad/Reshape_1",
-    # because it may vary between builds.
-    closure_elements = stepper.closure_elements()
-    op_with_output_in_closure = None
-    for element_name in closure_elements:
-      if element_name + ":0" in closure_elements:
-        op_with_output_in_closure = str(element_name)
-        break
-
-    self.assertEqual([0],
-                     stepper.output_slots_in_closure(op_with_output_in_closure))
-
-    self.assertIsNotNone(op_with_output_in_closure)
-    output_tensor = op_with_output_in_closure + ":0"
-
-    # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
-    # stepper, because it is the control input to another o. However, its
-    # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
-    # closure, because it is the (non-control) input of certain ops. Calling
-    # cont() on the op should lead to the caching of the tensor handle for
-    # the output tensor.
-    stepper.cont(op_with_output_in_closure)
-
-    self.assertEqual([output_tensor], stepper.handle_names())
-    self.assertSetEqual({op_with_output_in_closure},
-                        stepper.handle_node_names())
-
-    # Do a cont() call that uses the cached tensor of
-    # "gradients/?_grad/Reshape_1:0".
-    stepper.cont(output_tensor)
-    self.assertEqual({
-        output_tensor: NodeStepper.FEED_TYPE_HANDLE
-    }, stepper.last_feed_types())
+    with NodeStepper(self.sess, "optim") as stepper:
+      # In the transitive closure of the stepper, look for an op of which the
+      # output tensor also is in the transitive closure.
+      # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
+      # because it may vary between builds.
+      closure_elements = stepper.closure_elements()
+      op_with_output_in_closure = None
+      for element_name in closure_elements:
+        if element_name + ":0" in closure_elements:
+          op_with_output_in_closure = str(element_name)
+          break
+
+      self.assertEqual(
+          [0], stepper.output_slots_in_closure(op_with_output_in_closure))
+
+      self.assertIsNotNone(op_with_output_in_closure)
+      output_tensor = op_with_output_in_closure + ":0"
+
+      # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
+      # stepper, because it is the control input to another op. However, its
+      # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
+      # closure, because it is the (non-control) input of certain ops. Calling
+      # cont() on the op should lead to the caching of the tensor handle for
+      # the output tensor.
+      stepper.cont(op_with_output_in_closure)
+
+      self.assertEqual([output_tensor], stepper.handle_names())
+      self.assertSetEqual({op_with_output_in_closure},
+                          stepper.handle_node_names())
+
+      # Do a cont() call that uses the cached tensor of
+      # "gradients/?_grad/Reshape_1:0".
+      stepper.cont(output_tensor)
+      self.assertEqual({
+          output_tensor: NodeStepper.FEED_TYPE_HANDLE
+      }, stepper.last_feed_types())
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 145008e902..00ab7277cb 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -427,9 +427,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
     elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
           run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
       if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
-        retvals = self.invoke_node_stepper(
-            stepper.NodeStepper(self._sess, fetches, feed_dict),
-            restore_variable_values_on_exit=True)
+        with stepper.NodeStepper(
+            self._sess, fetches, feed_dict) as node_stepper:
+          retvals = self.invoke_node_stepper(
+              node_stepper, restore_variable_values_on_exit=True)
 
       # Invoke run() method of the wrapped session.
       retvals = self._sess.run(
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 30f0e117e6..58a8f7f9c9 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -107,10 +107,13 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
      run_context.session.graph._finalized = False
      # pylint: enable=protected-access
 
-      self.invoke_node_stepper(
-          stepper.NodeStepper(run_context.session, run_context.original_args.
-                              fetches, run_context.original_args.feed_dict),
-          restore_variable_values_on_exit=True)
+      with stepper.NodeStepper(
+          run_context.session,
+          run_context.original_args.
+          fetches,
+          run_context.original_args.feed_dict) as node_stepper:
+        self.invoke_node_stepper(
+            node_stepper, restore_variable_values_on_exit=True)
 
     return run_args
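
Usage note: the tests above exercise the new context-manager form of NodeStepper together with dumped intermediate tensors. Below is a rough, non-authoritative sketch of that usage pattern. The toy graph is invented for illustration; the NodeStepper calls (cont(), intermediate_tensor_names(), last_feed_types(), finalize(), and the use_dumped_intermediates argument) are the ones appearing in this diff.

import tensorflow as tf
from tensorflow.python.debug import stepper

# Toy graph, assumed for this sketch: e = (a * b) * 4.
a = tf.Variable(2.0, name="a")
b = tf.Variable(3.0, name="b")
c = tf.multiply(a, b, name="c")
e = tf.multiply(c, tf.constant(4.0, name="d"), name="e")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# NodeStepper is now a context manager; its cached tensor handles and dumped
# intermediate tensors are cleaned up when the block exits.
with stepper.NodeStepper(sess, e) as node_stepper:
  node_stepper.cont("c:0")  # Caches a tensor handle for c:0.
  # Intermediate tensors dumped during the cont() call become readable here.
  print(node_stepper.intermediate_tensor_names())
  # A later cont() can feed from the cached handle and the dumped
  # intermediates, or skip the dumps entirely:
  print(node_stepper.cont("e:0", use_dumped_intermediates=False))
  print(node_stepper.last_feed_types())
  print(node_stepper.finalize())  # Runs the original fetch.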