-rw-r--r--  tensorflow/python/debug/BUILD                      |    1
-rw-r--r--  tensorflow/python/debug/cli/debugger_cli_common.py |    1
-rw-r--r--  tensorflow/python/debug/cli/stepper_cli.py         |   90
-rw-r--r--  tensorflow/python/debug/cli/stepper_cli_test.py    |  485
-rw-r--r--  tensorflow/python/debug/stepper.py                 |  228
-rw-r--r--  tensorflow/python/debug/stepper_test.py            | 1394
-rw-r--r--  tensorflow/python/debug/wrappers/framework.py      |    7
-rw-r--r--  tensorflow/python/debug/wrappers/hooks.py          |   11
8 files changed, 1307 insertions(+), 910 deletions(-)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 22eb840791..339a6a72e0 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -52,6 +52,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":debug_data",
+ ":debug_utils",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:session_ops",
"@six_archive//:six",
diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py
index f456a2387c..f4274d958d 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common.py
@@ -90,6 +90,7 @@ class RichLine(object):
ret = RichLine()
if isinstance(other, str):
ret.text = self.text + other
+ ret.font_attr_segs = self.font_attr_segs[:]
return ret
elif isinstance(other, RichLine):
ret.text = self.text + other.text
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index 377fcc0387..bb76c440bc 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -41,7 +41,7 @@ class NodeStepperCLI(object):
STATE_CONT = "H"
# State where an intermediate dump of the tensor is available.
- STATE_INTERMEDIATE = "I"
+ STATE_DUMPED_INTERMEDIATE = "I"
# State where the element is already overridden.
STATE_OVERRIDDEN = "O"
@@ -53,6 +53,8 @@ class NodeStepperCLI(object):
# this NodeStepperCLI instance.
STATE_DIRTY_VARIABLE = "D"
+  # State where the element is not feedable, e.g., a reference-type tensor.
+  STATE_UNFEEDABLE = "U"
+
NEXT_NODE_POINTER_STR = "-->"
_MESSAGE_TEMPLATES = {
@@ -63,6 +65,15 @@ class NodeStepperCLI(object):
"Please use full tensor name.",
}
+ _STATE_COLORS = {
+ STATE_CONT: "green",
+ STATE_DIRTY_VARIABLE: "magenta",
+ STATE_DUMPED_INTERMEDIATE: "blue",
+ STATE_OVERRIDDEN: "yellow",
+ STATE_IS_PLACEHOLDER: "cyan",
+ STATE_UNFEEDABLE: "red",
+ }
+
def __init__(self, node_stepper):
self._node_stepper = node_stepper
@@ -98,6 +109,14 @@ class NodeStepperCLI(object):
type=str,
help="Name of the Tensor or Op to continue to.")
ap.add_argument(
+ "-i",
+ "--invalidate_from_updated_variables",
+ dest="invalidate_from_updated_variables",
+ action="store_true",
+ help="Whether to invalidate the cached "
+      "tensor handles and dumped intermediate tensors affected "
+ "by Variable updates in this continue call.")
+ ap.add_argument(
"-r",
"--restore_variable_values",
dest="restore_variable_values",
@@ -211,6 +230,7 @@ class NodeStepperCLI(object):
verbose = True
handle_node_names = self._node_stepper.handle_node_names()
+ intermediate_tensor_names = self._node_stepper.intermediate_tensor_names()
override_names = self._node_stepper.override_names()
dirty_variable_names = [
dirty_variable.split(":")[0]
@@ -242,6 +262,7 @@ class NodeStepperCLI(object):
labels, label_font_attr_segs = self._get_status_labels(
element_name,
handle_node_names,
+ intermediate_tensor_names,
override_names,
dirty_variable_names,
len(node_prefix))
@@ -262,6 +283,7 @@ class NodeStepperCLI(object):
def _get_status_labels(self,
element_name,
handle_node_names,
+ intermediate_tensor_names,
override_names,
dirty_variable_names,
offset):
@@ -278,6 +300,7 @@ class NodeStepperCLI(object):
element_name: (str) name of the graph element.
handle_node_names: (list of str) Names of the nodes of which the output
tensors' handles are available.
+      intermediate_tensor_names: (list of str) Names of the Tensors for which
+        intermediate dump values are available.
override_names: (list of str) Names of the tensors of which the values
are overridden.
dirty_variable_names: (list of str) Names of the dirty variables.
@@ -292,19 +315,31 @@ class NodeStepperCLI(object):
status = RL(" " * offset)
node_name = element_name.split(":")[0]
- status += RL("P", "cyan") if node_name in self._placeholders else " "
- status += (RL("U", "red")
+ status += (RL(self.STATE_IS_PLACEHOLDER,
+ self._STATE_COLORS[self.STATE_IS_PLACEHOLDER])
+ if node_name in self._placeholders else " ")
+ status += (RL(self.STATE_UNFEEDABLE,
+ self._STATE_COLORS[self.STATE_UNFEEDABLE])
if not self._node_stepper.is_feedable(str(element_name))
else " ")
- status += (RL("H", "green") if element_name in handle_node_names else " ")
+ status += (RL(self.STATE_CONT, self._STATE_COLORS[self.STATE_CONT])
+ if element_name in handle_node_names else " ")
+
+ intermediate_node_names = [
+ tensor_name.split(":")[0] for tensor_name in intermediate_tensor_names]
+ status += (RL(self.STATE_DUMPED_INTERMEDIATE,
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])
+ if element_name in intermediate_node_names else " ")
slots = self._node_stepper.output_slots_in_closure(element_name)
has_override = any(element_name + ":%d" % slot in override_names
for slot in slots)
- status += RL("O", "yellow") if has_override else " "
- status += (RL(self.STATE_DIRTY_VARIABLE, "magenta")
- if element_name in dirty_variable_names
- else " ")
+ status += (RL(self.STATE_OVERRIDDEN,
+ self._STATE_COLORS[self.STATE_OVERRIDDEN])
+ if has_override else " ")
+ status += (RL(self.STATE_DIRTY_VARIABLE,
+ self._STATE_COLORS[self.STATE_DIRTY_VARIABLE])
+ if element_name in dirty_variable_names else " ")
# TODO(ebreck) Return status here, once the caller is updated with the
# RichLine API.
@@ -320,14 +355,30 @@ class NodeStepperCLI(object):
return debugger_cli_common.rich_text_lines_from_rich_line_list([
RL(""),
RL("Legend:"),
- RL(" ") + RL("P", "cyan") + " - Placeholder",
- RL(" ") + RL("U", "red") + " - Unfeedable",
- (RL(" ") + RL("H", "green") +
+ (RL(" ") +
+ RL(self.STATE_IS_PLACEHOLDER,
+ self._STATE_COLORS[self.STATE_IS_PLACEHOLDER]) +
+ " - Placeholder"),
+ (RL(" ") +
+ RL(self.STATE_UNFEEDABLE,
+ self._STATE_COLORS[self.STATE_UNFEEDABLE]) +
+ " - Unfeedable"),
+ (RL(" ") +
+ RL(self.STATE_CONT,
+ self._STATE_COLORS[self.STATE_CONT]) +
" - Already continued-to; Tensor handle available from output "
"slot(s)"),
- (RL(" ") + RL("O", "yellow") +
+ (RL(" ") +
+ RL(self.STATE_DUMPED_INTERMEDIATE,
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE]) +
+       " - Dumped intermediate tensor available"),
+ (RL(" ") +
+ RL(self.STATE_OVERRIDDEN,
+ self._STATE_COLORS[self.STATE_OVERRIDDEN]) +
" - Has overriding (injected) tensor value"),
- (RL(" ") + RL("D", "magenta") +
+ (RL(" ") +
+ RL(self.STATE_DIRTY_VARIABLE,
+ self._STATE_COLORS[self.STATE_DIRTY_VARIABLE]) +
" - Dirty variable: Variable already updated this node stepper.")])
def cont(self, args, screen_info=None):
@@ -347,6 +398,8 @@ class NodeStepperCLI(object):
cont_result = self._node_stepper.cont(
parsed.target_name,
+ invalidate_from_updated_variables=(
+ parsed.invalidate_from_updated_variables),
restore_variable_values=parsed.restore_variable_values)
self._completed_nodes.add(parsed.target_name.split(":")[0])
@@ -363,8 +416,16 @@ class NodeStepperCLI(object):
lines.append(feed_info_line)
if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
font_attr_segs[line_counter] = [
- (len(feed_name) + 2, len(feed_info_line), "green")
+ (len(feed_name) + 2,
+ len(feed_info_line),
+               self._STATE_COLORS[self.STATE_CONT])
]
+ elif (feed_types[feed_name] ==
+ stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE):
+ font_attr_segs[line_counter] = [(
+ len(feed_name) + 2,
+ len(feed_info_line),
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])]
elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
font_attr_segs[line_counter] = [
(len(feed_name) + 2, len(feed_info_line), "yellow")
@@ -542,4 +603,3 @@ class NodeStepperCLI(object):
return [(element_name + ":%d" % slot) for slot in slots]
else:
return []
-
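A hedged usage sketch of the CLI changes above (not part of the patch; the small Variable graph and node names are made up for illustration). The new "I" status letter marks nodes whose intermediate tensor values have been dumped, and the new -i / --invalidate_from_updated_variables flag is forwarded to NodeStepper.cont():

from tensorflow.python.client import session
from tensorflow.python.debug import stepper
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables

a = variables.Variable(2.0, name="a")
b = variables.Variable(3.0, name="b")
c = math_ops.add(a, b, name="c")
d = math_ops.add(c, a, name="d")

sess = session.Session()
sess.run(variables.global_variables_initializer())

with stepper.NodeStepper(sess, d) as node_stepper:
  cli = stepper_cli.NodeStepperCLI(node_stepper)
  cli.cont(["c"])        # caches a handle for c:0; dumps a/read:0 and b/read:0
  cli.cont(["d", "-i"])  # reuses the c:0 handle and the dumped a/read:0;
                         # -i invalidates caches on Variable updates (none here)
  # The status column now carries "H" (handle cached) and "I" (intermediate
  # dumped) marks, colored per the _STATE_COLORS map added above.
  print("\n".join(cli.list_sorted_nodes([]).lines))
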
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 78c56697ae..0dd4493c95 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -140,310 +140,339 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertLess(node_names.index("e"), node_names.index("f"))
def testListingSortedNodesPresentsTransitveClosure(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self._assert_nodes_topologically_sorted_with_target_e(node_names)
- self.assertEqual(len(node_names), len(stat_labels))
- for stat_label in stat_labels:
- self.assertEqual(" ", stat_label)
- self.assertEqual(0, node_pointer)
+ self._assert_nodes_topologically_sorted_with_target_e(node_names)
+ self.assertEqual(len(node_names), len(stat_labels))
+ for stat_label in stat_labels:
+ self.assertEqual(" ", stat_label)
+ self.assertEqual(0, node_pointer)
def testListingSortedNodesLabelsPlaceholders(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))
+ with stepper.NodeStepper(self.sess, self.f) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self._assert_nodes_topologically_sorted_with_target_f(node_names)
+ self._assert_nodes_topologically_sorted_with_target_f(node_names)
- index_ph = node_names.index("ph")
- self.assertEqual(len(node_names), len(stat_labels))
- for i in xrange(len(stat_labels)):
- if index_ph == i:
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
- stat_labels[i])
- else:
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
- stat_labels[i])
+ index_ph = node_names.index("ph")
+ self.assertEqual(len(node_names), len(stat_labels))
+ for i in xrange(len(stat_labels)):
+ if index_ph == i:
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
+ stat_labels[i])
+ else:
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
+ stat_labels[i])
- self.assertEqual(0, node_pointer)
+ self.assertEqual(0, node_pointer)
def testContToNonexistentNodeShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))
+ with stepper.NodeStepper(self.sess, self.f) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["foobar"])
- self.assertEqual([
- "ERROR: foobar is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.cont(["foobar"])
+ self.assertEqual([
+ "ERROR: foobar is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testContToNodeOutsideTransitiveClosureShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["f"])
- self.assertEqual([
- "ERROR: f is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.cont(["f"])
+ self.assertEqual([
+ "ERROR: f is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testContToValidNodeShouldUpdateStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_c = node_names.index("c")
- self.assertEqual(" ", stat_labels[index_c])
- self.assertEqual(0, node_pointer)
+ index_c = node_names.index("c")
+ self.assertEqual(" ", stat_labels[index_c])
+ self.assertEqual(0, node_pointer)
- output = cli.cont("c")
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.cont("c")
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self.assertGreaterEqual(len(node_names), 3)
- self.assertIn("c", node_names)
- index_c = node_names.index("c")
- self.assertEqual(index_c, node_pointer)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
+ self.assertGreaterEqual(len(node_names), 3)
+ self.assertIn("c", node_names)
+ index_c = node_names.index("c")
+ self.assertEqual(index_c, node_pointer)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
- output = cli.cont("d")
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.cont("d")
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- used_feed_types = _parsed_used_feeds(output.lines)
- self.assertEqual({"c:0": "handle"}, used_feed_types)
+ used_feed_types = _parsed_used_feeds(output.lines)
+ self.assertEqual({
+ "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, used_feed_types)
- self.assertGreaterEqual(len(node_names), 3)
- self.assertIn("d", node_names)
- index_d = node_names.index("d")
- self.assertEqual(index_d, node_pointer)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
+ self.assertGreaterEqual(len(node_names), 3)
+ self.assertIn("d", node_names)
+ index_d = node_names.index("d")
+ self.assertEqual(index_d, node_pointer)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual(0, node_pointer)
-
- for i in xrange(len(orig_node_names)):
+ output = cli.list_sorted_nodes([])
+ orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual(0, node_pointer)
+
+ for i in xrange(len(orig_node_names)):
+ output = cli.step([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
+
+ next_node_name = node_names[node_pointer]
+ self.assertEqual(orig_node_names[i], next_node_name)
+
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[node_pointer])
+
+ # The order in which the nodes are listed should not change as the
+ # stepping happens.
+ output = cli.list_sorted_nodes([])
+ node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual(orig_node_names, node_names)
+
+ if i < len(orig_node_names) - 1:
+ self.assertEqual(i + 1, node_pointer)
+ else:
+ # Stepped over the limit. Pointer should be at -1.
+ self.assertEqual(-1, node_pointer)
+
+      # An attempt to step once more after the end has been reached should
+      # error out.
output = cli.step([])
+ self.assertEqual([
+ "ERROR: Cannot step any further because the end of the sorted "
+ "transitive closure has been reached."
+ ], output.lines)
+
+ def testSteppingMultipleStepsUpdatesStatus(self):
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+
+ output = cli.list_sorted_nodes([])
+ orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)
+
+ output = cli.step(["-t", "3"])
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
- next_node_name = node_names[node_pointer]
- self.assertEqual(orig_node_names[i], next_node_name)
+ self.assertEqual(orig_node_names[2], node_names[node_pointer])
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[node_pointer])
+ for i in xrange(node_pointer):
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
- # The order in which the nodes are listed should not change as the
- # stepping happens.
- output = cli.list_sorted_nodes([])
- node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual(orig_node_names, node_names)
+ for i in xrange(node_pointer + 1, len(stat_labels)):
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
- if i < len(orig_node_names) - 1:
- self.assertEqual(i + 1, node_pointer)
- else:
- # Stepped over the limit. Pointer should be at -1.
- self.assertEqual(-1, node_pointer)
+ def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ sorted_nodes = node_stepper.sorted_nodes()
+ closure_elements = node_stepper.closure_elements()
+
+ # Find a node which is in the list of sorted nodes, but whose output
+ # Tensor is not in the transitive closure.
+ no_output_node = None
+ for node in sorted_nodes:
+ if (node + ":0" not in closure_elements and
+ node + ":1" not in closure_elements):
+ no_output_node = node
+ break
+
+ self.assertIsNotNone(no_output_node)
+
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont([no_output_node])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- # Attempt to step once more after the end has been reached should error out.
- output = cli.step([])
- self.assertEqual([
- "ERROR: Cannot step any further because the end of the sorted "
- "transitive closure has been reached."
- ], output.lines)
+ self.assertEqual(no_output_node, node_names[node_pointer])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[node_pointer])
- def testSteppingMultipleStepsUpdatesStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ def testContToUpdateNodeWithTrackingLeadsToDirtyVariableLabel(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_b/ApplyGradientDescent", "-i"])
- output = cli.list_sorted_nodes([])
- orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
- output = cli.step(["-t", "3"])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_b/ApplyGradientDescent"])
- self.assertEqual(orig_node_names[2], node_names[node_pointer])
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
- for i in xrange(node_pointer):
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
+ def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_a/ApplyGradientDescent",
+ "--invalidate_from_updated_variables"])
- for i in xrange(node_pointer + 1, len(stat_labels)):
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
+ # After cont() call on .../update_a/..., Variable a should have been
+ # marked as dirty, whereas b should not have.
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
- def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
- node_stepper = stepper.NodeStepper(self.sess, self.opt)
-
- sorted_nodes = node_stepper.sorted_nodes()
- closure_elements = node_stepper.closure_elements()
-
- # Find a node which is in the list of sorted nodes, but whose output tensor
- # is not in the transitive closure.
- no_output_node = None
- for node in sorted_nodes:
- if (node + ":0" not in closure_elements and
- node + ":1" not in closure_elements):
- no_output_node = node
- break
-
- self.assertIsNotNone(no_output_node)
-
- cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont([no_output_node])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
-
- self.assertEqual(no_output_node, node_names[node_pointer])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[node_pointer])
-
- def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
- output = cli.cont(["opt/update_b/ApplyGradientDescent"])
-
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
+ output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
- def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
- output = cli.cont(["opt/update_a/ApplyGradientDescent"])
-
- # After cont() call on .../update_a/..., Variable a should have been marked
- # as dirty, whereas b should not have.
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
-
- output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])
-
- # After cont() call on .../update_b/... with the -r flag, Variable b should
- # have been marked as dirty, whereas Variable a should not be because it
- # should have been restored.
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
+ # After cont() call on .../update_b/... with the -r flag, Variable b
+ # should have been marked as dirty, whereas Variable a should not be
+ # because it should have been restored.
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
def testPrintTensorShouldWorkWithTensorName(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.cont("d")
- output = cli.print_tensor(["d:0"])
+ cli.cont("d")
+ output = cli.print_tensor(["d:0"])
- self.assertEqual("Tensor \"d:0\":", output.lines[0])
- self.assertEqual("-20.0", output.lines[-1])
+ self.assertEqual("Tensor \"d:0\":", output.lines[0])
+ self.assertEqual("-20.0", output.lines[-1])
def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.cont("d")
- output = cli.print_tensor(["d"])
+ cli.cont("d")
+ output = cli.print_tensor(["d"])
- self.assertEqual("Tensor \"d:0\":", output.lines[0])
- self.assertEqual("-20.0", output.lines[-1])
+ self.assertEqual("Tensor \"d:0\":", output.lines[0])
+ self.assertEqual("-20.0", output.lines[-1])
def testPrintTensorShouldWorkSlicingString(self):
ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
- cli = stepper_cli.NodeStepperCLI(
- stepper.NodeStepper(
- self.sess, self.f, feed_dict={self.ph: ph_value}))
+ with stepper.NodeStepper(
+ self.sess, self.f, feed_dict={self.ph: ph_value}) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor(["ph:0[:, 1]"])
- self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
- self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
+ output = cli.print_tensor(["ph:0[:, 1]"])
+ self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
+ self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
- output = cli.print_tensor(["ph[:, 1]"])
- self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
- self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
+ output = cli.print_tensor(["ph[:, 1]"])
+ self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
+ self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
def testPrintTensorWithNonexistentTensorShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor(["foobar"])
- self.assertEqual([
- "ERROR: foobar is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.print_tensor(["foobar"])
+ self.assertEqual([
+ "ERROR: foobar is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testPrintTensorWithNoHandleShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor("e")
- self.assertEqual([
- "This stepper instance does not have access to the value of tensor "
- "\"e:0\""
- ], output.lines)
+ output = cli.print_tensor("e")
+ self.assertEqual([
+ "This stepper instance does not have access to the value of tensor "
+ "\"e:0\""
+ ], output.lines)
def testInjectTensorValueByTensorNameShouldBeReflected(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["d"])
- node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual("d", node_names[node_pointer])
+ output = cli.cont(["d"])
+ node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual("d", node_names[node_pointer])
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_d = node_names.index("d")
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
- stat_labels[index_d])
+ index_d = node_names.index("d")
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
+ stat_labels[index_d])
- self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))
+ self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))
- output = cli.inject_value(["d:0", "20.0"])
+ output = cli.inject_value(["d:0", "20.0"])
- # Verify that the override is available.
- self.assertEqual(["d:0"], node_stepper.override_names())
+ # Verify that the override is available.
+ self.assertEqual(["d:0"], node_stepper.override_names())
- # Verify that the list of sorted nodes reflects the existence of the value
- # override (i.e., injection).
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ # Verify that the list of sorted nodes reflects the existence of the value
+ # override (i.e., injection).
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_d = node_names.index("d")
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[index_d])
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
- stat_labels[index_d])
+ index_d = node_names.index("d")
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[index_d])
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
+ stat_labels[index_d])
def testInjectTensorValueByNodeNameShouldBeReflected(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.inject_value(["d", "20.0"])
- self.assertEqual(["d:0"], node_stepper.override_names())
+ cli.inject_value(["d", "20.0"])
+ self.assertEqual(["d:0"], node_stepper.override_names())
def testInjectToNonexistentTensorShouldError(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
-
- output = cli.inject_value(["foobar:0", "20.0"])
- self.assertEqual([
- "ERROR: foobar:0 is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+
+ output = cli.inject_value(["foobar:0", "20.0"])
+ self.assertEqual([
+ "ERROR: foobar:0 is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
if __name__ == "__main__":
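The next file, stepper.py, reworks NodeStepper into a context manager that deletes its temporary dump directories on exit, and lets cont() reuse dumped intermediate tensors (FEED_TYPE_DUMPED_INTERMEDIATE) alongside cached handles, overrides, and client feeds. A minimal usage sketch, not part of the patch (graph and values are made up, loosely mirroring stepper_test.py further below):

from tensorflow.python.client import session
from tensorflow.python.debug import stepper
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables

a = variables.Variable(2.0, name="a")
b = variables.Variable(3.0, name="b")
c = math_ops.add(a, b, name="c")
e = math_ops.add(c, a, name="e")

sess = session.Session()
sess.run(variables.global_variables_initializer())

# Exiting the "with" block removes the temporary tfdbg_stepper_* dump root.
with stepper.NodeStepper(sess, e) as node_stepper:
  print(node_stepper.cont("c:0"))                  # 5.0; dumps a/read:0, b/read:0
  print(node_stepper.intermediate_tensor_names())  # names of the dumped tensors

  # Reuses the cached handle for c:0 and the dumped a/read:0 instead of
  # recomputing them; last_feed_types() reports which feed types were used.
  print(node_stepper.cont("e:0"))                  # 7.0
  print(node_stepper.last_feed_types())

  # Overriding c:0 invalidates cached handles and dumps downstream of it, so
  # the next cont() recomputes e:0 from the injected value.
  node_stepper.override_tensor("c:0", 7.0)
  print(node_stepper.cont("e:0"))                  # 9.0
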
diff --git a/tensorflow/python/debug/stepper.py b/tensorflow/python/debug/stepper.py
index fb5e972627..3cbe83d072 100644
--- a/tensorflow/python/debug/stepper.py
+++ b/tensorflow/python/debug/stepper.py
@@ -18,11 +18,16 @@ from __future__ import division
from __future__ import print_function
import copy
+import os
+import shutil
+import tempfile
+import time
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug import debug_data
+from tensorflow.python.debug import debug_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import session_ops
@@ -64,9 +69,17 @@ class NodeStepper(object):
tree of the target. When it reaches an input where one of the following is
available, it will supply the available value to the feed_dict of the cont()
call:
- (1) TensorHandles from previous cont() calls.
- (2) Overriding (injected) values from the client.
- (3) Feeds supplied during the construction of the stepper instance.
+ (1) Overriding (injected) values from the client.
+ (2) TensorHandles from previous cont() calls.
+ (3) Dumped intermediate Tensors from previous cont() calls.
+ (4) Feeds supplied during the construction of the stepper instance.
+
+ During the cont() call, intermediate Tensors are dumped to temporary
+ directories. The dumped Tensor values will be used in subsequent cont() calls
+ when they are required as data dependencies.
+
+  The temporary directories are automatically cleaned up when the NodeStepper
+  instance exits as a context manager.
Once the tracing is complete, it will issue a run() call on the
underlying session, using the aforementioned feed_dict prepared by the input
@@ -95,10 +108,7 @@ class NodeStepper(object):
FEED_TYPE_CLIENT = "client"
FEED_TYPE_HANDLE = "handle"
FEED_TYPE_OVERRIDE = "override"
-
- # TODO(cais): The following member constant is currently unused. Use it when
- # the stepper is capable of using dumped intermediate tensors.
- FEED_TYPE_INTERMEDIATE = "intermediate"
+ FEED_TYPE_DUMPED_INTERMEDIATE = "dumped_intermediate"
def __init__(self, sess, fetches, feed_dict=None):
"""Constructor for Debugger.
@@ -125,11 +135,15 @@ class NodeStepper(object):
self._variable_initial_values = {}
# Initialize the map for output recipients (targets).
- self._non_control_output_targets = {}
+ self._output_targets = {}
# Sorted transitive closure of the fetched node.
- self._sorted_nodes, self._closure_elements = self._dfs_visit(
- self._sess.graph, self._fetch_list)
+ # We also collect the list of the names of the reference-type Tensors,
+ # because we later need to avoid using intermediate dumps for such Tensors.
+ (self._sorted_nodes,
+ self._closure_elements,
+ self._ref_tensor_names) = self._dfs_visit(self._sess.graph,
+ self._fetch_list)
self._transitive_closure_set = set(self._sorted_nodes)
@@ -146,6 +160,11 @@ class NodeStepper(object):
# tensor handles.
self._tensor_handles = {}
+ # Cached intermediate tensor values: a dict mapping tensor names to
+ # DebugTensorDatum.
+ self._dumped_intermediate_tensors = {}
+ self._dump_session_root = tempfile.mkdtemp(prefix="tfdbg_stepper_")
+
# Feed dict from the client.
self._client_feed_dict = {}
if feed_dict:
@@ -161,6 +180,13 @@ class NodeStepper(object):
# What the feed types were used by the last cont() call.
self._last_feed_types = {}
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ if os.path.isdir(self._dump_session_root):
+ shutil.rmtree(self._dump_session_root)
+
def _get_fetch_and_name_lists(self, flattened_fetches):
"""Get the lists of fetches and their names.
@@ -220,6 +246,11 @@ class NodeStepper(object):
# Graph elements in the transitive closure, including the nodes and tensors.
closure_elements = [elem.name for elem in elem_list]
+ ref_tensor_names = set()
+ for element in elem_list:
+ if isinstance(element, ops.Tensor) and element.dtype._is_ref_dtype: # pylint: disable=protected-access
+ ref_tensor_names.add(element.name)
+
while elem_stack:
curr_elem = elem_stack.pop()
curr_node = self._get_node(curr_elem)
@@ -238,22 +269,21 @@ class NodeStepper(object):
# Iterate through the (non-control) inputs.
for inp in all_inputs:
- is_non_control_input = inp in non_control_inputs
-
# Set up the non-control output map.
- if is_non_control_input:
- if inp.name not in self._non_control_output_targets:
- self._non_control_output_targets[inp.name] = set([curr_elem.name])
- else:
- self._non_control_output_targets[inp.name].add(curr_elem.name)
-
- if (inp.op.type in ["Variable", "VariableV2"] and
- inp.name not in self._variable_initializers):
- # Obtain the initializer op of the variable, in case the Variable's
- # value needs to be restored later.
- initializer = graph.as_graph_element(inp.op.name + "/Assign")
- self._variable_initializers[inp.name] = initializer
- self._variable_initial_values[inp.name] = initializer.inputs[1]
+ # if is_non_control_input:
+ if inp.name not in self._output_targets:
+ self._output_targets[inp.name] = set([curr_elem.name])
+ else:
+ self._output_targets[inp.name].add(curr_elem.name)
+
+ if (isinstance(inp, ops.Tensor) and
+ inp.op.type in ["Variable", "VariableV2"] and
+ inp.name not in self._variable_initializers):
+ # Obtain the initializer op of the variable, in case the Variable's
+ # value needs to be restored later.
+ initializer = graph.as_graph_element(inp.op.name + "/Assign")
+ self._variable_initializers[inp.name] = initializer
+ self._variable_initial_values[inp.name] = initializer.inputs[1]
inp_node = self._get_node(inp)
if inp_node.name in done:
@@ -262,6 +292,8 @@ class NodeStepper(object):
elem_stack.append(inp)
closure_elements.append(inp.name)
+ if isinstance(inp, ops.Tensor) and inp.dtype._is_ref_dtype: # pylint: disable=protected-access
+ ref_tensor_names.add(inp.name)
# Now that we have traversed the transitive closure and obtained the
# node-input map, we can topologically sort them.
@@ -291,7 +323,7 @@ class NodeStepper(object):
stack.extend(pushes)
- return sorted_nodes, closure_elements
+ return sorted_nodes, closure_elements, ref_tensor_names
def sorted_nodes(self):
"""Get a topologically-sorted list of node names of the stepper.
@@ -409,10 +441,15 @@ class NodeStepper(object):
return self._last_feed_types
+ # TODO(cais): Add method last_updated_variables() to allow client to look
+ # up the Variables updated in the last cont() call.
+
def cont(self,
target,
use_tensor_handles=True,
+ use_dumped_intermediates=True,
use_overrides=True,
+ invalidate_from_updated_variables=False,
restore_variable_values=False):
"""Continue till the completion of the specified target tensor.
@@ -423,8 +460,13 @@ class NodeStepper(object):
# TODO(cais): Support multiple fetches as in Session.run() interface.
use_tensor_handles: (bool) Whether this cont() run will use cached tensor
handles to avoid recomputation. Default: True.
+      use_dumped_intermediates: (bool) Whether this cont() call will use dumped
+        intermediate tensors to avoid recomputation. Default: True.
use_overrides: (bool) Whether the overriding tensor values supplied by
the client are to be used in this cont() call. Default: True.
+      invalidate_from_updated_variables: (bool) Whether to invalidate the
+        cached tensor handles and dumped intermediate tensors affected by the
+        Variable updates that happen in this cont() call. Default: False.
restore_variable_values: (bool) Whether the old values of the variables
(before any cont() calls in this object) are to be restored.
@@ -441,9 +483,6 @@ class NodeStepper(object):
self._last_feed_types = {}
- # The feeds to be used in the Session.run() call.
- feeds = {}
-
if isinstance(target, six.string_types):
# Fetch target is a string. Assume it is the name of the Tensor or Op and
# will attempt to find it in the Session's graph.
@@ -494,6 +533,12 @@ class NodeStepper(object):
self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE
return self._tensor_handles[target_name].eval()
+ # Check if a dumped intermediate tensor can be used on the fetch directly.
+ if (use_dumped_intermediates and
+ target_name in self._dumped_intermediate_tensors):
+ self._last_feed_types[target_name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
+ return self._dumped_intermediate_tensors[target_name].get_tensor()
+
# Check if an overriding tensor value can be used directly.
if use_overrides and target_name in self._override_tensors:
# Override is available. Return the value right away.
@@ -510,6 +555,7 @@ class NodeStepper(object):
# =========================================================================
# Use a non-recursive method to trace the inputs from the node and set up
# the feeds.
+ feeds = {} # The feeds to be used in the Session.run() call.
fetched = self._sess.graph.as_graph_element(target_name)
elem_stack = [fetched]
done = set()
@@ -529,7 +575,7 @@ class NodeStepper(object):
# Determine whether the input is feedable. Reference-type tensors,
# e.g., Variables, should not be fed, because they can change.
if isinstance(inp, ops.Tensor):
- is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
+ is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
else:
is_inp_ref = False
@@ -574,11 +620,17 @@ class NodeStepper(object):
# Use client-supplied overriding tensor value.
feeds[inp] = self._override_tensors[inp.name]
self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE
- elif (use_tensor_handles and can_feed and
- inp.name in self._tensor_handles and inp not in feeds):
+ elif (can_feed and inp not in feeds and
+ use_tensor_handles and inp.name in self._tensor_handles):
# Tensor handle found in cache.
feeds[inp] = self._tensor_handles[inp.name].eval()
self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE
+ elif (can_feed and inp not in feeds and
+ use_dumped_intermediates and
+ inp.name in self._dumped_intermediate_tensors):
+ # Dumped intermediate Tensor found.
+ feeds[inp] = self._dumped_intermediate_tensors[inp.name].get_tensor()
+ self._last_feed_types[inp.name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
elif inp.name in self._client_feed_dict:
# This input is available in the client feed_dict.
feeds[inp] = self._client_feed_dict[inp.name]
@@ -602,14 +654,12 @@ class NodeStepper(object):
for variable in restored_variables:
self._dirty_variables.remove(variable)
- # Prepare RunOptions for DebugTensorWatches
- run_options = config_pb2.RunOptions()
- # TODO(cais): Add fields for watching intermediate tensors.
-
+ (dump_path,
+ run_options) = self._prepare_cont_call_dump_path_and_run_options()
if isinstance(fetched, ops.Operation):
# The fetched is an Operation: Will not get tensor handle.
self._sess.run(fetched, feed_dict=feeds, options=run_options)
- # No return value for a run of an Operation
+ return_value = None
else:
# This is a Tensor: Will get tensor handle and cache it.
# Will also get the additional requested tensor handles (if any).
@@ -622,17 +672,58 @@ class NodeStepper(object):
])
handle_names.extend(additional_handle_requests)
- for handle_name, tensor in zip(handle_names, tensors_to_get_handles_for):
- handle = self._sess.run(session_ops.get_session_handle(tensor),
- feed_dict=feeds,
- options=run_options)
+ handles = self._sess.run(
+ [session_ops.get_session_handle(tensor) for tensor in
+ tensors_to_get_handles_for],
+ feed_dict=feeds,
+ options=run_options)
+ for handle_name, handle in zip(handle_names, handles):
self._tensor_handles[handle_name] = handle
- return self._tensor_handles[target_name].eval()
+ return_value = self._tensor_handles[target_name].eval()
+
+ self._load_dumped_intermediate_tensors(dump_path, target_name)
- # Invalidate caches at the end.
- for touched_variable in touched_variables:
- self._invalidate_transitively_outgoing_cache(touched_variable)
+ if invalidate_from_updated_variables:
+ # Invalidate caches at the end.
+ for touched_variable in touched_variables:
+ self._invalidate_transitively_outgoing_cache(touched_variable)
+
+ return return_value
+
+ def _prepare_cont_call_dump_path_and_run_options(self):
+ """Prepare the dump path and RunOptions for next cont() call.
+
+ Returns:
+      dump_path: (str) Directory path to which the intermediate tensors will
+        be dumped.
+ run_options: (config_pb2.RunOptions) The RunOptions containing the tensor
+ watch options for this graph.
+ """
+ run_options = config_pb2.RunOptions()
+ dump_path = self._cont_call_dump_path()
+ for element_name in self._closure_elements:
+ if ":" in element_name:
+ debug_utils.add_debug_tensor_watch(
+ run_options,
+ debug_data.get_node_name(element_name),
+ output_slot=debug_data.get_output_slot(element_name),
+ debug_urls=["file://" + dump_path])
+
+ return dump_path, run_options
+
+ def _cont_call_dump_path(self):
+ return os.path.join(self._dump_session_root,
+ "cont_%d" % int(time.time() * 1e6))
+
+ def _load_dumped_intermediate_tensors(self, dump_path, target_name):
+ dump_dir = debug_data.DebugDumpDir(dump_path, validate=False)
+ for dump in dump_dir.dumped_tensor_data:
+ if (dump.tensor_name not in self._ref_tensor_names and
+ dump.tensor_name not in self._tensor_handles and
+ dump.tensor_name not in self._override_tensors and
+ dump.tensor_name != target_name):
+ self._dumped_intermediate_tensors[dump.tensor_name] = dump
def _get_node_name(self, graph_element_name):
return graph_element_name.split(":")[0]
@@ -646,29 +737,33 @@ class NodeStepper(object):
Uses non-recursive implementation to avoid stack overflow on deep networks.
- TODO(cais): Currently, only TensorHandle caches are invalidated. Invalidate
- cached intermediate tensor values from dumps when dumps are added.
-
Args:
source_element: The source graph element (e.g., a Variable output slot)
to trace the output from.
"""
- if not self._tensor_handles:
+ if not self._tensor_handles and not self._dumped_intermediate_tensors:
return
# First, use cached invalidation paths to eliminate some cached tensor
- # handles.
- to_delete = []
+ # handles and intermediate tensors.
+ to_delete_handles = []
for handle_name in self._tensor_handles:
if (handle_name in self._cached_invalidation_path and
source_element in self._cached_invalidation_path[handle_name]):
- to_delete.append(handle_name)
-
- for handle_name in to_delete:
+ to_delete_handles.append(handle_name)
+ for handle_name in to_delete_handles:
del self._tensor_handles[handle_name]
- if not self._tensor_handles:
+ to_delete_intermediates = []
+ for intm_tensor_name in self._dumped_intermediate_tensors:
+ if (intm_tensor_name in self._cached_invalidation_path and
+ source_element in self._cached_invalidation_path[intm_tensor_name]):
+ to_delete_intermediates.append(intm_tensor_name)
+ for intermediate in to_delete_intermediates:
+ del self._dumped_intermediate_tensors[intermediate]
+
+ if not self._tensor_handles and not self._dumped_intermediate_tensors:
return
stack = [source_element]
@@ -676,19 +771,22 @@ class NodeStepper(object):
while stack:
curr_element = stack.pop()
-
done.add(curr_element)
- if curr_element in self._tensor_handles:
+ if (curr_element in self._tensor_handles or
+ curr_element in self._dumped_intermediate_tensors):
# Cache the invalidation path for potential future use.
if curr_element not in self._cached_invalidation_path:
self._cached_invalidation_path[curr_element] = set([source_element])
else:
self._cached_invalidation_path[curr_element].add(source_element)
- del self._tensor_handles[curr_element]
+ if curr_element in self._tensor_handles:
+ del self._tensor_handles[curr_element]
+ else:
+ del self._dumped_intermediate_tensors[curr_element]
- targets = self._non_control_output_targets.get(curr_element, [])
+ targets = self._output_targets.get(curr_element, [])
for target in targets:
if target in done:
continue
@@ -740,6 +838,16 @@ class NodeStepper(object):
return set([self._get_node_name(name) for name in self._tensor_handles])
+ def intermediate_tensor_names(self):
+ """Get list of the names of the Tensors for which dumps are available.
+
+ Returns:
+ (list of str) List of the names of the Tensors for which intermediate
+ dumps are available.
+ """
+
+    return list(self._dumped_intermediate_tensors.keys())
+
def dirty_variables(self):
"""Get the set of variables that are currently "dirty".
@@ -817,6 +925,8 @@ class NodeStepper(object):
return self._override_tensors[tensor_name]
elif tensor_name in self._tensor_handles:
return self._tensor_handles[tensor_name].eval()
+ elif tensor_name in self._dumped_intermediate_tensors:
+ return self._dumped_intermediate_tensors[tensor_name].get_tensor()
else:
raise ValueError(
"This stepper instance does not have access to the value of "
diff --git a/tensorflow/python/debug/stepper_test.py b/tensorflow/python/debug/stepper_test.py
index 20d79f3ade..63501b4fe6 100644
--- a/tensorflow/python/debug/stepper_test.py
+++ b/tensorflow/python/debug/stepper_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import gradient_descent
@@ -54,277 +55,310 @@ class StepperTest(test_util.TensorFlowTestCase):
self.sess = session.Session()
self.sess.run(variables.global_variables_initializer())
- self.sess = session.Session()
- self.sess.run(variables.global_variables_initializer())
-
def tearDown(self):
ops.reset_default_graph()
def testContToFetchNotInTransitiveClosureShouldError(self):
- stepper = NodeStepper(self.sess, "e:0")
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(7, len(sorted_nodes))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
- self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
-
- self.assertSetEqual(
- {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"},
- set(stepper.closure_elements()))
-
- with self.assertRaisesRegexp(
- ValueError,
- "Target \"f:0\" is not in the transitive closure for the fetch of the "
- "stepper"):
- stepper.cont("f:0")
-
- def testContToNodeNameShouldReturnTensorvalue(self):
- stepper = NodeStepper(self.sess, "e:0")
-
- cont_result = stepper.cont("c")
- self.assertAllClose(6.0, cont_result)
+ with NodeStepper(self.sess, "e:0") as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(7, len(sorted_nodes))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
+ self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
- def testUsingNamesNotUsingIntermediateTensors(self):
- stepper = NodeStepper(self.sess, "e:0")
+ self.assertSetEqual(
+ {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"},
+ set(stepper.closure_elements()))
- # The first cont() call should have used no feeds.
- result = stepper.cont("c:0")
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Target \"f:0\" is not in the transitive closure for the fetch of "
+ "the stepper"):
+ stepper.cont("f:0")
- # The second cont() call should have used the tensor handle from the
- # previous cont() call.
- result = stepper.cont("e:0")
- self.assertAllClose(24.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ def testContToNodeNameShouldReturnTensorValue(self):
+ with NodeStepper(self.sess, "e:0") as stepper:
+ self.assertAllClose(6.0, stepper.cont("c"))
- def testUsingNodesNotUsingIntermediateTensors(self):
- stepper = NodeStepper(self.sess, self.e)
+ def testUsingNamesNotUsingIntermediateTensors(self):
+ with NodeStepper(self.sess, "e:0") as stepper:
+ # The first cont() call should have used no feeds.
+ result = stepper.cont("c:0")
+ self.assertAllClose(6.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # The second cont() call should have used the tensor handle from the
+ # previous cont() call.
+ result = stepper.cont("e:0")
+ self.assertAllClose(24.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+ self.assertAllClose(4.0, stepper.get_tensor_value("d:0"))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
- # There should be no handles before any cont() calls.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
-
- # Before the cont() call, the stepper should not have access to the value
- # of c:0.
- with self.assertRaisesRegexp(
- ValueError,
- "This stepper instance does not have access to the value of tensor "
- "\"c:0\""):
- stepper.get_tensor_value("c:0")
-
- # Using the node/tensor itself, instead of the name str, should work on
- # cont().
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertEqual({"c"}, stepper.handle_node_names())
-
- # After the cont() call, the stepper should have access to the value of c:0
- # via a tensor handle.
- self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))
-
- result = stepper.cont(self.e)
- self.assertAllClose(24.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ def testUsingNodesNotUsingIntermediateTensors(self):
+ with NodeStepper(self.sess, self.e) as stepper:
+ # There should be no handles before any cont() calls.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+
+ # Before the cont() call, the stepper should not have access to the value
+ # of c:0.
+ with self.assertRaisesRegexp(
+ ValueError,
+ "This stepper instance does not have access to the value of tensor "
+ "\"c:0\""):
+ stepper.get_tensor_value("c:0")
+
+ # Using the node/tensor itself, instead of the name str, should work on
+ # cont().
+ result = stepper.cont(self.c)
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertEqual({"c"}, stepper.handle_node_names())
+
+ # After the cont() call, the stepper should have access to the value of
+ # c:0 via a tensor handle.
+ self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))
+
+ result = stepper.cont(self.e)
+ self.assertAllClose(24.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
+ stepper.intermediate_tensor_names())
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ def testContToTensorWithIntermediateDumpShouldUseDump(self):
+ with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
+ stepper.cont("c:0")
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+
+ self.assertAllClose(2.0, stepper.cont("a/read:0"))
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+
+ self.assertAllClose(10.0, stepper.cont("f:0"))
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+
+ def testDisablingUseDumpedIntermediatesWorks(self):
+ with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
+ stepper.cont("c:0")
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+
+ self.assertAllClose(10.0,
+ stepper.cont("f:0", use_dumped_intermediates=False))
+ self.assertEqual({}, stepper.last_feed_types())
def testIsFeedableShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.e)
-
- self.assertTrue(stepper.is_feedable("a/read:0"))
- self.assertTrue(stepper.is_feedable("b/read:0"))
- self.assertTrue(stepper.is_feedable("c:0"))
- self.assertTrue(stepper.is_feedable("d:0"))
+ with NodeStepper(self.sess, self.e) as stepper:
+ self.assertTrue(stepper.is_feedable("a/read:0"))
+ self.assertTrue(stepper.is_feedable("b/read:0"))
+ self.assertTrue(stepper.is_feedable("c:0"))
+ self.assertTrue(stepper.is_feedable("d:0"))
def testOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
-
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # There should be no overrides before any cont() calls.
- self.assertEqual([], stepper.override_names())
-
- # Calling cont() on c again should lead to use of the handle.
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
-
- # After the overriding, calling get_tensor_value() on c:0 should yield the
- # overriding value.
- self.assertEqual(7.0, stepper.get_tensor_value("c:0"))
-
- # Now c:0 should have only an override value, but no cached handle, because
- # the handle should have been invalidated.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- # Run a downstream tensor after the value override.
- result = stepper.cont(self.e)
- self.assertAllClose(28.0, result) # Should reflect the overriding value.
-
- # Should use override, instead of the handle.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # There should be no overrides before any cont() calls.
+ self.assertEqual([], stepper.override_names())
+
+ # Calling cont() on c again should lead to use of the handle.
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
+
+ # After the overriding, calling get_tensor_value() on c:0 should yield the
+ # overriding value.
+ self.assertEqual(7.0, stepper.get_tensor_value("c:0"))
+
+ # Now c:0 should have only an override value, but no cached handle,
+ # because the handle should have been invalidated.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ # Run a downstream tensor after the value override.
+ result = stepper.cont(self.e)
+ self.assertAllClose(28.0, result) # Should reflect the overriding value.
+
+ # Should use override, instead of the handle.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
def testOverrideValueTwice(self):
- stepper = NodeStepper(self.sess, self.e)
-
- # Override once.
- stepper.override_tensor("c:0", 7.0)
- self.assertAllClose(28.0, stepper.cont(self.e))
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- self.assertEqual(["e:0"], stepper.handle_names())
- self.assertSetEqual({"e"}, stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- # Calling cont(self.e) again. This time the cached tensor handle of e
- # should be used.
- self.assertEqual(28.0, stepper.cont(self.e))
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Override c again. This should have invalidated the cache for e.
- stepper.override_tensor("c:0", 8.0)
-
- self.assertEqual([], stepper.handle_names())
- self.assertEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- self.assertAllClose(32.0, stepper.cont(self.e))
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, self.e) as stepper:
+ # Override once.
+ stepper.override_tensor("c:0", 7.0)
+ self.assertAllClose(28.0, stepper.cont(self.e))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
+
+ self.assertEqual(["e:0"], stepper.handle_names())
+ self.assertSetEqual({"e"}, stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ # Calling cont(self.e) again. This time the cached tensor handle of e
+ # should be used.
+ self.assertEqual(28.0, stepper.cont(self.e))
+ self.assertEqual({
+ "e:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Override c again. This should have invalidated the cache for e.
+ stepper.override_tensor("c:0", 8.0)
+
+ self.assertEqual([], stepper.handle_names())
+ self.assertEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ self.assertAllClose(32.0, stepper.cont(self.e))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
def testRemoveOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
-
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # The previous cont() step should have generated a cached tensor handle.
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertSetEqual({"c"}, stepper.handle_node_names())
-
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
-
- # The overriding should have invalidated the tensor handle.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- result = stepper.cont(self.e)
- self.assertAllClose(28.0, result) # Should reflect the overriding value.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # The handle to tensor e:0 should have been cached, even though its
- # transitive closure contains an override.
- self.assertIn("e:0", stepper.handle_names())
- self.assertSetEqual({"e"}, stepper.handle_node_names())
-
- # Remove the override.
- stepper.remove_override("c:0")
- # c:0 should not be in the overrides anymore.
- self.assertEqual([], stepper.override_names())
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # The previous cont() step should have generated a cached tensor handle.
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertSetEqual({"c"}, stepper.handle_node_names())
+
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
+
+ # The overriding should have invalidated the tensor handle.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ result = stepper.cont(self.e)
+ self.assertAllClose(28.0, result) # Should reflect the overriding value.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ # The handle to tensor e:0 should have been cached, even though its
+ # transitive closure contains an override.
+ self.assertIn("e:0", stepper.handle_names())
+ self.assertSetEqual({"e"}, stepper.handle_node_names())
+
+ # Remove the override.
+ stepper.remove_override("c:0")
+ # c:0 should not be in the overrides anymore.
+ self.assertEqual([], stepper.override_names())
- # Removing the override should have invalidated the tensor handle for c.
- self.assertNotIn("e:0", stepper.handle_names())
- self.assertNotIn("e", stepper.handle_node_names())
+ # Removing the override should have invalidated the tensor handle for c.
+ self.assertNotIn("e:0", stepper.handle_names())
+ self.assertNotIn("e", stepper.handle_node_names())
- # Should reflect the non-overriding value.
- self.assertAllClose(24.0, stepper.cont(self.e))
+ # Should reflect the non-overriding value.
+ self.assertAllClose(24.0, stepper.cont(self.e))
- # This time, the handle to tensor e:0 should have been cached again, even
- # thought its transitive closure contains an override.
- self.assertIn("e:0", stepper.handle_names())
- self.assertIn("e", stepper.handle_node_names())
+ # This time, the handle to tensor e:0 should have been cached again, even
+      # though its transitive closure contains an override.
+ self.assertIn("e:0", stepper.handle_names())
+ self.assertIn("e", stepper.handle_node_names())
- # Calling cont(self.e) again should have used the tensor handle to e:0.
- self.assertAllClose(24.0, stepper.cont(self.e))
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ # Calling cont(self.e) again should have used the tensor handle to e:0.
+ self.assertAllClose(24.0, stepper.cont(self.e))
+ self.assertEqual({
+ "e:0": NodeStepper.FEED_TYPE_HANDLE,
+ }, stepper.last_feed_types())
def testOverrideAndContToSameTensor(self):
- stepper = NodeStepper(self.sess, self.e)
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertSetEqual({"c"}, stepper.handle_node_names())
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertSetEqual({"c"}, stepper.handle_node_names())
+ self.assertAllClose(6.0, stepper.cont(self.c))
- self.assertAllClose(6.0, stepper.cont(self.c))
+ # The last cont() call should use the tensor handle directly.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
- # The last cont() call should use the tensor handle directly.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
+ # As a result of the override, the tensor handle should have been
+ # invalidated.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
- # As a result of the override, the tensor handle should have been
- # invalidated.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
+ result = stepper.cont(self.c)
+ self.assertAllClose(7.0, result)
- result = stepper.cont(self.c)
- self.assertAllClose(7.0, result)
-
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
def testFinalizeWithPreviousOverrides(self):
- stepper = NodeStepper(self.sess, self.e)
+ with NodeStepper(self.sess, self.e) as stepper:
+ stepper.override_tensor("a/read:0", 20.0)
+ self.assertEqual(["a/read:0"], stepper.override_names())
- stepper.override_tensor("a/read:0", 20.0)
- self.assertEqual(["a/read:0"], stepper.override_names())
+ # Should reflect the overriding value.
+ self.assertAllClose(24000.0, stepper.cont("e:0"))
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
- # Should reflect the overriding value.
- self.assertAllClose(24000.0, stepper.cont("e:0"))
- self.assertEqual({
- "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # Finalize call should have ignored the overriding value.
- self.assertAllClose(24.0, stepper.finalize())
+ # Finalize call should have ignored the overriding value.
+ self.assertAllClose(24.0, stepper.finalize())
def testRemoveNonexistentOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
- self.assertEqual([], stepper.override_names())
-
- with self.assertRaisesRegexp(
- ValueError, "No overriding value exists for tensor \"c:0\""):
- stepper.remove_override("c:0")
+ with NodeStepper(self.sess, self.e) as stepper:
+ self.assertEqual([], stepper.override_names())
+ with self.assertRaisesRegexp(
+ ValueError, "No overriding value exists for tensor \"c:0\""):
+ stepper.remove_override("c:0")
def testAttemptToOverrideInvalidTensor(self):
stepper = NodeStepper(self.sess, self.e)
@@ -333,20 +367,18 @@ class StepperTest(test_util.TensorFlowTestCase):
stepper.override_tensor("f:0", 42.0)
def testInvalidOverrideArgumentType(self):
- stepper = NodeStepper(self.sess, self.e)
-
- with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
- stepper.override_tensor(self.a, 42.0)
+ with NodeStepper(self.sess, self.e) as stepper:
+ with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
+ stepper.override_tensor(self.a, 42.0)
def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self):
- stepper = NodeStepper(self.sess, "z:0")
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(4, len(sorted_nodes))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
- self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
+ with NodeStepper(self.sess, "z:0") as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(4, len(sorted_nodes))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
+ self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
for i in range(6):
@@ -363,45 +395,44 @@ class StepperTest(test_util.TensorFlowTestCase):
elif i == 5:
fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}}
- stepper = NodeStepper(self.sess, fetches)
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(13, len(sorted_nodes))
-
- # Check the topological order of the sorted nodes.
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
- self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
-
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
- self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f"))
- self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f"))
-
- closure_elements = stepper.closure_elements()
- self.assertIn("x/read:0", closure_elements)
- self.assertIn("e:0", closure_elements)
- self.assertIn("f:0", closure_elements)
-
- self.assertEqual([0], stepper.output_slots_in_closure("x/read"))
- self.assertEqual([0], stepper.output_slots_in_closure("e"))
- self.assertEqual([0], stepper.output_slots_in_closure("f"))
-
- result = stepper.finalize()
- if i == 0 or i == 1 or i == 3 or i == 4:
- self.assertAllClose(24.0, result[0])
- self.assertAllClose(10.0, result[1][0])
- self.assertAllClose(-4.0, result[1][1])
- elif i == 2 or i == 5:
- self.assertAllClose(24.0, result["e"])
- self.assertAllClose(10.0, result["fz"]["f"])
- self.assertAllClose(-4.0, result["fz"]["z"])
+ with NodeStepper(self.sess, fetches) as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(13, len(sorted_nodes))
+
+ # Check the topological order of the sorted nodes.
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
+ self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
+
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
+ self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f"))
+ self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f"))
+
+ closure_elements = stepper.closure_elements()
+ self.assertIn("x/read:0", closure_elements)
+ self.assertIn("e:0", closure_elements)
+ self.assertIn("f:0", closure_elements)
+
+ self.assertEqual([0], stepper.output_slots_in_closure("x/read"))
+ self.assertEqual([0], stepper.output_slots_in_closure("e"))
+ self.assertEqual([0], stepper.output_slots_in_closure("f"))
+
+ result = stepper.finalize()
+ if i == 0 or i == 1 or i == 3 or i == 4:
+ self.assertAllClose(24.0, result[0])
+ self.assertAllClose(10.0, result[1][0])
+ self.assertAllClose(-4.0, result[1][1])
+ elif i == 2 or i == 5:
+ self.assertAllClose(24.0, result["e"])
+ self.assertAllClose(10.0, result["fz"]["f"])
+ self.assertAllClose(-4.0, result["fz"]["z"])
class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
@@ -419,128 +450,222 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
ops.reset_default_graph()
def testGetTensorValueWorksOnPlaceholder(self):
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
self.ph1: [[-1.0], [0.5]]
- })
-
- self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
- stepper.get_tensor_value("ph0"))
- self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
- stepper.get_tensor_value("ph0:0"))
- with self.assertRaisesRegexp(
- KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
- stepper.get_tensor_value("ph0:1")
+ }) as stepper:
+ self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
+ stepper.get_tensor_value("ph0"))
+ self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
+ stepper.get_tensor_value("ph0:0"))
+ with self.assertRaisesRegexp(
+ KeyError,
+ r"The name 'ph0:1' refers to a Tensor which does not exist"):
+ stepper.get_tensor_value("ph0:1")
def testIsPlaceholdersShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.y)
-
- self.assertTrue(stepper.is_placeholder(self.ph0.name))
- self.assertTrue(stepper.is_placeholder(self.ph1.name))
+ with NodeStepper(self.sess, self.y) as stepper:
+ self.assertTrue(stepper.is_placeholder(self.ph0.name))
+ self.assertTrue(stepper.is_placeholder(self.ph1.name))
- self.assertFalse(stepper.is_placeholder(self.x.name))
- self.assertFalse(stepper.is_placeholder(self.y.name))
+ self.assertFalse(stepper.is_placeholder(self.x.name))
+ self.assertFalse(stepper.is_placeholder(self.y.name))
- with self.assertRaisesRegexp(ValueError,
- "A is not in the transitive closure"):
- self.assertFalse(stepper.is_placeholder("A"))
+ with self.assertRaisesRegexp(ValueError,
+ "A is not in the transitive closure"):
+ self.assertFalse(stepper.is_placeholder("A"))
def testPlaceholdersShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.y)
-
- self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
+ with NodeStepper(self.sess, self.y) as stepper:
+ self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
def testContWithPlaceholders(self):
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
self.ph1: [[-1.0], [0.5]]
- })
-
- self.assertEqual(4, len(stepper.sorted_nodes()))
- self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"},
- set(stepper.closure_elements()))
-
- result = stepper.cont(self.x)
- self.assertAllClose([[0.0], [5.5]], result)
- self.assertEqual({
- "ph0:0": NodeStepper.FEED_TYPE_CLIENT,
- "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
- }, stepper.last_feed_types())
-
- self.assertEqual(["x:0"], stepper.handle_names())
- self.assertSetEqual({"x"}, stepper.handle_node_names())
-
- result = stepper.cont(self.y)
- self.assertAllClose([[-1.0], [6.0]], result)
- self.assertEqual({
- "x:0": NodeStepper.FEED_TYPE_HANDLE,
- "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
- }, stepper.last_feed_types())
+ }) as stepper:
+ self.assertEqual(4, len(stepper.sorted_nodes()))
+ self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"},
+ set(stepper.closure_elements()))
+
+ result = stepper.cont(self.x)
+ self.assertAllClose([[0.0], [5.5]], result)
+ self.assertEqual({
+ "ph0:0": NodeStepper.FEED_TYPE_CLIENT,
+ "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
+ }, stepper.last_feed_types())
+
+ self.assertEqual(["x:0"], stepper.handle_names())
+ self.assertSetEqual({"x"}, stepper.handle_node_names())
+
+ result = stepper.cont(self.y)
+ self.assertAllClose([[-1.0], [6.0]], result)
+ self.assertEqual({
+ "x:0": NodeStepper.FEED_TYPE_HANDLE,
+ "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
+ }, stepper.last_feed_types())
def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self):
"""Continuing to a placeholder should be allowed, using client feed."""
ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
ph1_feed = [[-1.0], [0.5]]
- stepper = NodeStepper(
+ with NodeStepper(
self.sess, self.y, feed_dict={
self.ph0: ph0_feed,
self.ph1: ph1_feed,
- })
-
- self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ }) as stepper:
+ self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
- self.assertEqual({
- self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
+ self.assertEqual({
+ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- ph0_node = self.sess.graph.as_graph_element("ph0")
- self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ ph0_node = self.sess.graph.as_graph_element("ph0")
+ self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
+ self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
ph1_feed = [[-1.0], [0.5]]
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0.name: ph0_feed,
self.ph1.name: ph1_feed,
- })
+ }) as stepper:
+ self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
+ self.assertEqual({
+ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
- self.assertEqual({
- self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ ph0_node = self.sess.graph.as_graph_element("ph0")
+ self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- ph0_node = self.sess.graph.as_graph_element("ph0")
- self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
- self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
+
+class StepperAssignAddTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.v = variables.Variable(10.0, name="v")
+ self.p = math_ops.add(self.v, self.v, name="p")
+ self.q = math_ops.multiply(self.p, self.p, name="q")
+ self.delta = constant_op.constant(2.0, name="delta")
+ self.v_add = state_ops.assign_add(self.v, self.delta, name="v_add")
+ self.v_add_plus_one = math_ops.add(self.v_add,
+ 1.0,
+ name="v_add_plus_one")
+
+ self.sess = session.Session()
+ self.sess.run(self.v.initializer)
+
+ def tearDown(self):
+ ops.reset_default_graph()
+
+  def testContToUpdateInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
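+      # With the setUp values v = 10.0 and delta = 2.0:
+      # p = v + v = 20.0 and q = p * p = 400.0.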
+ self.assertAllClose(400.0, stepper.cont("q:0"))
+ self.assertItemsEqual(["v/read:0", "p:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
+ self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
+
+ self.assertAllClose(
+ 12.0, stepper.cont(
+ self.v_add, invalidate_from_updated_variables=True))
+ self.assertAllClose(12.0, self.sess.run(self.v))
+ self.assertItemsEqual(["v:0"], stepper.dirty_variables())
+ # Updating the value of v by calling v_add should have invalidated the
+ # dumped intermediate tensors for v/read:0 and p:0.
+ self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names())
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"This stepper instance does not have access to the value of tensor "
+ r"\"p:0\""):
+ stepper.get_tensor_value("p:0")
+
+ # The next cont to q should not have used any dumped intermediate tensors
+ # and its result should reflect the updated value.
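+      # With the updated v = 12.0: q = 24.0 * 24.0 = 576.0.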
+ self.assertAllClose(576.0, stepper.cont("q:0"))
+ self.assertEqual({}, stepper.last_feed_types())
+
+ def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, self.q) as stepper:
+ self.assertAllClose(400.0, stepper.cont("q:0"))
+ self.assertItemsEqual(["v/read:0", "p:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
+ self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
+
+ stepper.override_tensor("v/read:0", 11.0)
+ self.assertItemsEqual(["v/read:0"], stepper.override_names())
+ # Overriding the upstream v/read:0 should have invalidated the dumped
+ # intermediate tensor for the downstream p:0.
+ self.assertItemsEqual([], stepper.intermediate_tensor_names())
+
+ # The next cont to q should not have used any dumped intermediate tensors
+ # and its result should reflect the overriding value.
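+      # With v/read:0 overridden to 11.0: q = 22.0 * 22.0 = 484.0.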
+ self.assertAllClose(484.0, stepper.cont("q:0"))
+ self.assertEqual({
+ "v/read:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
+
+ def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, self.q) as stepper:
+ stepper.override_tensor("v/read:0", 9.0)
+ self.assertItemsEqual(["v/read:0"], stepper.override_names())
+
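+      # With v/read:0 overridden to 9.0: q = 18.0 * 18.0 = 324.0.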
+ self.assertAllClose(324.0, stepper.cont(self.q))
+ self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names())
+
+ stepper.remove_override("v/read:0")
+ self.assertItemsEqual([], stepper.override_names())
+ # Removing the pre-existing override to v/read:0 should have invalidated
+ # the dumped intermediate tensor.
+ self.assertItemsEqual([], stepper.intermediate_tensor_names())
+
+ def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self):
+ with NodeStepper(self.sess, self.v_add) as stepper:
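+      # v starts at 10.0; a single assign_add of delta (2.0) brings it to
+      # 12.0. Repeated cont() calls on v_add should reuse the cached handle
+      # instead of applying the update again.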
+ stepper.cont(self.v_add)
+ self.assertAllClose(12.0, stepper.cont(self.v))
+ stepper.cont(self.v_add)
+ self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE},
+ stepper.last_feed_types())
+ self.assertAllClose(12.0, stepper.cont(self.v))
+
+ def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self):
+ with NodeStepper(self.sess, self.v_add_plus_one) as stepper:
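+      # Same as above, but cont() is called on a node downstream of the
+      # assign_add op; v should still be updated only once (10.0 -> 12.0).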
+ stepper.cont(self.v_add_plus_one)
+ self.assertAllClose(12.0, stepper.cont(self.v))
+ stepper.cont(self.v_add_plus_one)
+ self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE},
+ stepper.last_feed_types())
+ self.assertAllClose(12.0, stepper.cont(self.v))
class StepperBackwardRunTest(test_util.TensorFlowTestCase):
@@ -580,93 +705,136 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
ops.reset_default_graph()
def testContToUpdateA(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("a:0")
- self.assertAllClose(1.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- result = stepper.cont("optim/learning_rate:0")
- self.assertAllClose(0.01, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # Before any cont calls on ApplyGradientDescent, there should be no "dirty"
- # variables.
- self.assertEqual(set(), stepper.dirty_variables())
-
- # First, all the two control inputs to optim.
- result = stepper.cont("optim/update_a/ApplyGradientDescent")
-
- # Now variable a should have been marked as dirty due to the update
- # by optim/update_a/ApplyGradientDescent.
- self.assertEqual({"a:0"}, stepper.dirty_variables())
- self.assertIsNone(result)
- self.assertEqual({
- "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Check that Variable "a" has been updated properly, but "b", "c" and "d"
- # remain the same.
- # For backprop on Variable a:
- # Because f = a * b * b * c, df / da = b * b * c.
- # 1.0 - learning_rate * b * b * c
- # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(2.0, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("a:0")
+ self.assertAllClose(1.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ result = stepper.cont("optim/learning_rate:0")
+ self.assertAllClose(0.01, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # Before any cont calls on ApplyGradientDescent, there should be no
+ # "dirty" variables.
+ self.assertEqual(set(), stepper.dirty_variables())
+
+      # First, cont() to one of the two control inputs to optim.
+ result = stepper.cont("optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True)
+
+ # Now variable a should have been marked as dirty due to the update
+ # by optim/update_a/ApplyGradientDescent.
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+ self.assertIsNone(result)
+ self.assertEqual({
+ "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Check that Variable "a" has been updated properly, but "b", "c" and "d"
+ # remain the same.
+ # For backprop on Variable a:
+ # Because f = a * b * b * c, df / da = b * b * c.
+ # 1.0 - learning_rate * b * b * c
+ # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testContToUpdateB(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("optim/update_b/ApplyGradientDescent")
- self.assertIsNone(result)
- self.assertEqual(set(["b:0"]), stepper.dirty_variables())
-
- # For backprop on Variable b:
- # Because f = a * b * b * c, df / da = 2 * a * b * c.
- # 2.0 - learning_rate * 2 * a * b * c
- # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
- self.assertAllClose(1.0, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True)
+ self.assertIsNone(result)
+ self.assertEqual(set(["b:0"]), stepper.dirty_variables())
+
+ # For backprop on Variable b:
+      # Because f = a * b * b * c, df / db = 2 * a * b * c.
+ # 2.0 - learning_rate * 2 * a * b * c
+ # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
+ self.assertAllClose(1.0, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testContAfterUpdateWithoutRestoringVariableValue(self):
- stepper = NodeStepper(self.sess, "optim")
-
- # First, update Variable a from 1.0 to 0.84.
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual(set(["a:0"]), stepper.dirty_variables())
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(2.0, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
-
- # Second, update Variable b without the default restore_variable_values.
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=False)
- self.assertIsNone(result)
- # For the backprop on Variable b under the updated value of a:
- # 2.0 - learning_rate * 2 * a' * b * c
- # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(1.8656, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ # First, update Variable a from 1.0 to 0.84.
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ self.assertEqual(set(["a:0"]), stepper.dirty_variables())
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+ # Tracking of the updated variables should have invalidated all
+      # intermediate tensors downstream of a:0.
+ self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("d:0", stepper.intermediate_tensor_names())
+
+ # Second, update Variable b without the default restore_variable_values.
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent", restore_variable_values=False)
+ self.assertIsNone(result)
+ # For the backprop on Variable b under the updated value of a:
+ # 2.0 - learning_rate * 2 * a' * b * c
+ # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.8656, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+
+ def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self):
+ with NodeStepper(self.sess, "optim") as stepper:
+ self.assertIsNone(stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=False))
+ # Even though invalidate_from_updated_variables is set to False, dirty
+ # variables should still have been tracked.
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+ self.assertIn("a/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("b/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("c/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("d:0", stepper.intermediate_tensor_names())
+ self.assertIn("e:0", stepper.intermediate_tensor_names())
+ self.assertIn("optim/learning_rate:0",
+ stepper.intermediate_tensor_names())
+ self.assertNotIn("a:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("b:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("c:0", stepper.intermediate_tensor_names())
+
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+
+ # For the backprop on Variable b, the result should reflect the original
+ # value of Variable a, even though Variable a has actually been updated.
+ # 2.0 - learning_rate * 2 * a * b * c
+ # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
+ self.assertIsNone(stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=False,
+ restore_variable_values=False))
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testUpdateTwiceRestoreVariable(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual({"a:0"}, stepper.dirty_variables())
-
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- # Variables a and c should have been restored and hence no longer dirty.
- # Variable b should have been marked as dirty.
- self.assertEqual({"b:0"}, stepper.dirty_variables())
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ # Variables a and c should have been restored and hence no longer dirty.
+ # Variable b should have been marked as dirty.
+ self.assertEqual({"b:0"}, stepper.dirty_variables())
     # The result of the update should be identical to what it would be if
     # only update_b is run.
@@ -680,179 +848,203 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
"clean" means no Variables have been updated by preceding cont() calls.
"""
- stepper = NodeStepper(self.sess, "optim")
-
- # First, call cont() on the two tensors on the intermediate level: e and f.
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
-
- # The cont call above should have restored Variable "b".
- result = stepper.cont("e:0")
- self.assertAllClose(8.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
-
- # Now run update_a, so as to let Variable a be diry.
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual({"a:0"}, stepper.dirty_variables())
-
- # Now, run update_b.
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
-
- # The last cont() run should have use the handle of tensor e, but not the
- # handle of tensor d, because the transitive closure of e is clean, whereas
- # that of d is dirty due to the update to a in the previous cont() call.
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # The result of the update_b should be identical to as if no other
- # update_* cont() calls have occurred before.
- self.assertAllClose(1.0, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+      # First, call cont() on the two tensors on the intermediate level: d
+      # and e.
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertItemsEqual(["d:0"], stepper.handle_names())
+ self.assertEqual(set(), stepper.dirty_variables())
+
+ result = stepper.cont("e:0")
+ self.assertAllClose(8.0, result)
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+ self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names())
+ self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertEqual(set(), stepper.dirty_variables())
+
+ # Now run update_a, so as to let Variable a be dirty.
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ # Due to the update to the value of a:0, the dumped intermediate a/read:0
+ # should have been invalidated.
+ self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
+      # The dumped intermediates b/read:0 and c/read:0 should remain.
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+
+ # Now, run update_b.
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
+ self.assertIsNone(result)
+
+      # The last cont() run should have used the handle of tensor e, but not
+ # handle of tensor d, because the transitive closure of e is clean,
+ # whereas that of d is dirty due to the update to a in the previous cont()
+ # call.
+ last_feed_types = stepper.last_feed_types()
+ self.assertNotIn("d:0", last_feed_types)
+ self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ last_feed_types["b/read:0"])
+ self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ last_feed_types["c/read:0"])
+
+      # The result of update_b should be identical to what it would be if no
+      # other update_* cont() calls had occurred before.
+ self.assertAllClose(1.0, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testRestoreVariableValues(self):
"""Test restore_variable_values() restores the old values of variables."""
- stepper = NodeStepper(self.sess, "optim")
-
- stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertAllClose(1.84, self.sess.run(self.b))
+ with NodeStepper(self.sess, "optim") as stepper:
+ stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertAllClose(1.84, self.sess.run(self.b))
- stepper.restore_variable_values()
- self.assertAllClose(2.0, self.sess.run(self.b))
+ stepper.restore_variable_values()
+ self.assertAllClose(2.0, self.sess.run(self.b))
def testFinalize(self):
"""Test finalize() to restore variables and run the original fetch."""
- stepper = NodeStepper(self.sess, "optim")
-
- # Invoke update_b before calling finalize.
- stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
+ with NodeStepper(self.sess, "optim") as stepper:
+ # Invoke update_b before calling finalize.
+ stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
- result = stepper.finalize()
- self.assertIsNone(result)
+ result = stepper.finalize()
+ self.assertIsNone(result)
- # The results of the Variable updates should be the same as if no cont()
- # call has occurred on update_b.
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(3.96, self.sess.run(self.c))
+ # The results of the Variable updates should be the same as if no cont()
+ # call has occurred on update_b.
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(3.96, self.sess.run(self.c))
- def testOverrideThenContToUpdate(self):
+ def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self):
"""Test cont() to update nodes after overriding tensor values."""
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
- self.assertEqual(["d:0"], stepper.handle_names())
- self.assertSetEqual({"d"}, stepper.handle_node_names())
-
- # Override the value from 1.0 to 10.0.
- stepper.override_tensor("a/read:0", 10.0)
-
- self.assertEqual(["a/read:0"], stepper.override_names())
-
- result = stepper.cont(
- "optim/update_c/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
-
- # The last cont() call should have not used the tensor handle to d:0,
- # because the transitive closure of d:0 contains an override tensor.
- self.assertEqual({
- "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # The tensor handle to d:0 should have been removed due to the dirty
- # transitive closure.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
-
- # For this backprop on c, the overriding value of a/read:0 should have been
- # used:
- # 4.0 - learning_rate * a * b * b
- # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
- self.assertAllClose(3.6, self.sess.run(self.c))
-
- # Now remove the overriding value of a/read:0.
- stepper.remove_override("a/read:0")
- self.assertEqual([], stepper.override_names())
-
- # Obtain the tensor handle to d:0 again.
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual(["d:0"], stepper.handle_names())
- self.assertSetEqual({"d"}, stepper.handle_node_names())
-
- # Then call update_c again, without restoring c.
- result = stepper.cont(
- "optim/update_c/ApplyGradientDescent", restore_variable_values=False)
- self.assertIsNone(result)
-
- # This time, the d:0 tensor handle should have been used, because its
- # transitive closure is clean.
- self.assertEqual({
- "d:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # For this backprop on c, the overriding value of a/read:0 should have been
- # used:
- # 3.6 - learning_rate * a * b * b
- # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
- self.assertAllClose(3.56, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertEqual(set(), stepper.dirty_variables())
+ self.assertEqual(["d:0"], stepper.handle_names())
+ self.assertSetEqual({"d"}, stepper.handle_node_names())
+
+ # Override the value from 1.0 to 10.0.
+ stepper.override_tensor("a/read:0", 10.0)
+
+ self.assertEqual(["a/read:0"], stepper.override_names())
+
+ result = stepper.cont(
+ "optim/update_c/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+
+      # The last cont() call should not have used the tensor handle to d:0,
+ # because the transitive closure of d:0 contains an override tensor.
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ # The tensor handle to d:0 should have been removed due to the dirty
+ # transitive closure.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+
+ # For this backprop on c, the overriding value of a/read:0 should have
+ # been used:
+ # 4.0 - learning_rate * a * b * b
+ # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
+ self.assertAllClose(3.6, self.sess.run(self.c))
+
+ # Now remove the overriding value of a/read:0.
+ stepper.remove_override("a/read:0")
+ self.assertEqual([], stepper.override_names())
+
+ # Obtain the tensor handle to d:0 again.
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual(["d:0"], stepper.handle_names())
+ self.assertSetEqual({"d"}, stepper.handle_node_names())
+ self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+ # Then call update_c again, without restoring c.
+ result = stepper.cont("optim/update_c/ApplyGradientDescent",
+ restore_variable_values=False)
+ self.assertIsNone(result)
+ self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+ # This time, the d:0 tensor handle should have been used, because its
+ # transitive closure is clean.
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ "d:0": NodeStepper.FEED_TYPE_HANDLE,
+ "optim/learning_rate:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+      # For this backprop on c, the original (non-overridden) value of
+      # a/read:0 (1.0) should have been used:
+ # 3.6 - learning_rate * a * b * b
+ # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
+ self.assertAllClose(3.56, self.sess.run(self.c))
def testContToNodeWithOutputTensors(self):
"""cont() to an op should cache its output tensors if appropriate."""
- stepper = NodeStepper(self.sess, "optim")
-
- # In the transitive closure of the stepper, look for an op of which the
- # output tensor also is in the transitive closure.
- # Do not assume a specific op, e.g., ""gradients/e_grad/Reshape_1",
- # because it may vary between builds.
- closure_elements = stepper.closure_elements()
- op_with_output_in_closure = None
- for element_name in closure_elements:
- if element_name + ":0" in closure_elements:
- op_with_output_in_closure = str(element_name)
- break
-
- self.assertEqual([0],
- stepper.output_slots_in_closure(op_with_output_in_closure))
-
- self.assertIsNotNone(op_with_output_in_closure)
- output_tensor = op_with_output_in_closure + ":0"
-
- # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
- # stepper, because it is the control input to another o. However, its
- # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
- # closure, because it is the (non-control) input of certain ops. Calling
- # cont() on the op should lead to the caching of the tensor handle for
- # the output tensor.
- stepper.cont(op_with_output_in_closure)
-
- self.assertEqual([output_tensor], stepper.handle_names())
- self.assertSetEqual({op_with_output_in_closure},
- stepper.handle_node_names())
-
- # Do a cont() call that uses the cached tensor of
- # "gradients/?_grad/Reshape_1:0".
- stepper.cont(output_tensor)
- self.assertEqual({
- output_tensor: NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, "optim") as stepper:
+ # In the transitive closure of the stepper, look for an op of which the
+ # output tensor also is in the transitive closure.
+      # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
+ # because it may vary between builds.
+ closure_elements = stepper.closure_elements()
+ op_with_output_in_closure = None
+ for element_name in closure_elements:
+ if element_name + ":0" in closure_elements:
+ op_with_output_in_closure = str(element_name)
+ break
+
+ self.assertEqual(
+ [0], stepper.output_slots_in_closure(op_with_output_in_closure))
+
+ self.assertIsNotNone(op_with_output_in_closure)
+ output_tensor = op_with_output_in_closure + ":0"
+
+ # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
+      # stepper, because it is the control input to another op. However, its
+ # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
+ # closure, because it is the (non-control) input of certain ops. Calling
+ # cont() on the op should lead to the caching of the tensor handle for
+ # the output tensor.
+ stepper.cont(op_with_output_in_closure)
+
+ self.assertEqual([output_tensor], stepper.handle_names())
+ self.assertSetEqual({op_with_output_in_closure},
+ stepper.handle_node_names())
+
+ # Do a cont() call that uses the cached tensor of
+ # "gradients/?_grad/Reshape_1:0".
+ stepper.cont(output_tensor)
+ self.assertEqual({
+ output_tensor: NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
if __name__ == "__main__":
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 145008e902..00ab7277cb 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -427,9 +427,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
- retvals = self.invoke_node_stepper(
- stepper.NodeStepper(self._sess, fetches, feed_dict),
- restore_variable_values_on_exit=True)
+ with stepper.NodeStepper(
+ self._sess, fetches, feed_dict) as node_stepper:
+ retvals = self.invoke_node_stepper(
+ node_stepper, restore_variable_values_on_exit=True)
# Invoke run() method of the wrapped session.
retvals = self._sess.run(
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 30f0e117e6..58a8f7f9c9 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -107,10 +107,13 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
run_context.session.graph._finalized = False
# pylint: enable=protected-access
- self.invoke_node_stepper(
- stepper.NodeStepper(run_context.session, run_context.original_args.
- fetches, run_context.original_args.feed_dict),
- restore_variable_values_on_exit=True)
+ with stepper.NodeStepper(
+ run_context.session,
+        run_context.original_args.fetches,
+ run_context.original_args.feed_dict) as node_stepper:
+ self.invoke_node_stepper(
+ node_stepper, restore_variable_values_on_exit=True)
return run_args