author    Shanqing Cai <cais@google.com>    2017-02-04 13:04:55 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-02-04 13:28:19 -0800
commit    59e6f82a1e4411925e6a32e1488ae6a5381b69e7 (patch)
tree      ac934361d95c21e495896509ccec7f7aa008bf76
parent    79c4c0153026effad515b0f8d286640a55d0010d (diff)
tfdbg: stepper: use dumped intermediate tensors during cont() calls
Each cont() call will use TFDBG's own RunOptions.debug_options to generate intermediate tensor dumps, load the dumps, and cache the relevant DebugTensorDatum objects. If future cont() calls require the intermediate tensor values, they will be obtained from the cached DebugTensorDatum objects, saving unnecessary recomputation. The use of such cached intermediate tensor dumps can be disabled in individual cont() calls by passing "use_dumped_intermediates=False". Change: 146572672
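A minimal usage sketch of the behavior described above (not part of this commit): it assumes the TF 1.x public API plus the internal tfdbg module path tensorflow.python.debug.stepper touched by this change, and an illustrative two-variable graph. It shows NodeStepper used as a context manager (so the temporary dump directories are removed on exit), a second cont() call being fed from a cached tensor handle plus dumped intermediate tensors, and the per-call opt-out via use_dumped_intermediates=False.

# Illustrative sketch only; module path and graph are assumptions for this example.
import tensorflow as tf

from tensorflow.python.debug import stepper  # internal tfdbg stepper module

a = tf.Variable(2.0, name="a")
b = tf.Variable(3.0, name="b")
c = tf.multiply(a, b, name="c")   # c = 6.0
d = tf.add(a, b, name="d")        # d = 5.0
e = tf.multiply(c, d, name="e")   # e = 30.0

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# NodeStepper is now a context manager, so the temporary directories holding
# the intermediate tensor dumps are removed on exit.
with stepper.NodeStepper(sess, e) as node_stepper:
  # The first cont() computes c and, as a side effect, dumps intermediate
  # tensors (e.g. a/read:0 and b/read:0) under a per-call dump directory.
  print(node_stepper.cont("c:0"))                  # 6.0
  print(node_stepper.intermediate_tensor_names())  # e.g. ['a/read:0', 'b/read:0']

  # The second cont() is fed from the cached tensor handle of c:0 and from
  # the dumped intermediates, instead of recomputing them.
  print(node_stepper.cont("e:0"))                  # 30.0
  print(node_stepper.last_feed_types())
  # e.g. {'c:0': 'handle', 'a/read:0': 'dumped_intermediate', ...}

  # The cached dumps can be bypassed for an individual call.
  node_stepper.cont("e:0", use_dumped_intermediates=False)
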
-rw-r--r--  tensorflow/python/debug/BUILD                         |    1
-rw-r--r--  tensorflow/python/debug/cli/debugger_cli_common.py    |    1
-rw-r--r--  tensorflow/python/debug/cli/stepper_cli.py            |   90
-rw-r--r--  tensorflow/python/debug/cli/stepper_cli_test.py       |  485
-rw-r--r--  tensorflow/python/debug/stepper.py                    |  228
-rw-r--r--  tensorflow/python/debug/stepper_test.py               | 1394
-rw-r--r--  tensorflow/python/debug/wrappers/framework.py         |    7
-rw-r--r--  tensorflow/python/debug/wrappers/hooks.py              |   11
8 files changed, 1307 insertions, 910 deletions
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 22eb840791..339a6a72e0 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -52,6 +52,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":debug_data",
+ ":debug_utils",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:session_ops",
"@six_archive//:six",
diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py
index f456a2387c..f4274d958d 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common.py
@@ -90,6 +90,7 @@ class RichLine(object):
ret = RichLine()
if isinstance(other, str):
ret.text = self.text + other
+ ret.font_attr_segs = self.font_attr_segs[:]
return ret
elif isinstance(other, RichLine):
ret.text = self.text + other.text
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index 377fcc0387..bb76c440bc 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -41,7 +41,7 @@ class NodeStepperCLI(object):
STATE_CONT = "H"
# State where an intermediate dump of the tensor is available.
- STATE_INTERMEDIATE = "I"
+ STATE_DUMPED_INTERMEDIATE = "I"
# State where the element is already overridden.
STATE_OVERRIDDEN = "O"
@@ -53,6 +53,8 @@ class NodeStepperCLI(object):
# this NodeStepperCLI instance.
STATE_DIRTY_VARIABLE = "D"
+ STATE_UNFEEDABLE = "U"
+
NEXT_NODE_POINTER_STR = "-->"
_MESSAGE_TEMPLATES = {
@@ -63,6 +65,15 @@ class NodeStepperCLI(object):
"Please use full tensor name.",
}
+ _STATE_COLORS = {
+ STATE_CONT: "green",
+ STATE_DIRTY_VARIABLE: "magenta",
+ STATE_DUMPED_INTERMEDIATE: "blue",
+ STATE_OVERRIDDEN: "yellow",
+ STATE_IS_PLACEHOLDER: "cyan",
+ STATE_UNFEEDABLE: "red",
+ }
+
def __init__(self, node_stepper):
self._node_stepper = node_stepper
@@ -98,6 +109,14 @@ class NodeStepperCLI(object):
type=str,
help="Name of the Tensor or Op to continue to.")
ap.add_argument(
+ "-i",
+ "--invalidate_from_updated_variables",
+ dest="invalidate_from_updated_variables",
+ action="store_true",
+ help="Whether to invalidate the cached "
+ "tensor handles and intermediate tensor handles affected "
+ "by Variable updates in this continue call.")
+ ap.add_argument(
"-r",
"--restore_variable_values",
dest="restore_variable_values",
@@ -211,6 +230,7 @@ class NodeStepperCLI(object):
verbose = True
handle_node_names = self._node_stepper.handle_node_names()
+ intermediate_tensor_names = self._node_stepper.intermediate_tensor_names()
override_names = self._node_stepper.override_names()
dirty_variable_names = [
dirty_variable.split(":")[0]
@@ -242,6 +262,7 @@ class NodeStepperCLI(object):
labels, label_font_attr_segs = self._get_status_labels(
element_name,
handle_node_names,
+ intermediate_tensor_names,
override_names,
dirty_variable_names,
len(node_prefix))
@@ -262,6 +283,7 @@ class NodeStepperCLI(object):
def _get_status_labels(self,
element_name,
handle_node_names,
+ intermediate_tensor_names,
override_names,
dirty_variable_names,
offset):
@@ -278,6 +300,7 @@ class NodeStepperCLI(object):
element_name: (str) name of the graph element.
handle_node_names: (list of str) Names of the nodes of which the output
tensors' handles are available.
+ intermediate_tensor_names: (list of str) Names of the tensors for which dumped intermediate values are available.
override_names: (list of str) Names of the tensors of which the values
are overridden.
dirty_variable_names: (list of str) Names of the dirty variables.
@@ -292,19 +315,31 @@ class NodeStepperCLI(object):
status = RL(" " * offset)
node_name = element_name.split(":")[0]
- status += RL("P", "cyan") if node_name in self._placeholders else " "
- status += (RL("U", "red")
+ status += (RL(self.STATE_IS_PLACEHOLDER,
+ self._STATE_COLORS[self.STATE_IS_PLACEHOLDER])
+ if node_name in self._placeholders else " ")
+ status += (RL(self.STATE_UNFEEDABLE,
+ self._STATE_COLORS[self.STATE_UNFEEDABLE])
if not self._node_stepper.is_feedable(str(element_name))
else " ")
- status += (RL("H", "green") if element_name in handle_node_names else " ")
+ status += (RL(self.STATE_CONT, self._STATE_COLORS[self.STATE_CONT])
+ if element_name in handle_node_names else " ")
+
+ intermediate_node_names = [
+ tensor_name.split(":")[0] for tensor_name in intermediate_tensor_names]
+ status += (RL(self.STATE_DUMPED_INTERMEDIATE,
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])
+ if element_name in intermediate_node_names else " ")
slots = self._node_stepper.output_slots_in_closure(element_name)
has_override = any(element_name + ":%d" % slot in override_names
for slot in slots)
- status += RL("O", "yellow") if has_override else " "
- status += (RL(self.STATE_DIRTY_VARIABLE, "magenta")
- if element_name in dirty_variable_names
- else " ")
+ status += (RL(self.STATE_OVERRIDDEN,
+ self._STATE_COLORS[self.STATE_OVERRIDDEN])
+ if has_override else " ")
+ status += (RL(self.STATE_DIRTY_VARIABLE,
+ self._STATE_COLORS[self.STATE_DIRTY_VARIABLE])
+ if element_name in dirty_variable_names else " ")
# TODO(ebreck) Return status here, once the caller is updated with the
# RichLine API.
@@ -320,14 +355,30 @@ class NodeStepperCLI(object):
return debugger_cli_common.rich_text_lines_from_rich_line_list([
RL(""),
RL("Legend:"),
- RL(" ") + RL("P", "cyan") + " - Placeholder",
- RL(" ") + RL("U", "red") + " - Unfeedable",
- (RL(" ") + RL("H", "green") +
+ (RL(" ") +
+ RL(self.STATE_IS_PLACEHOLDER,
+ self._STATE_COLORS[self.STATE_IS_PLACEHOLDER]) +
+ " - Placeholder"),
+ (RL(" ") +
+ RL(self.STATE_UNFEEDABLE,
+ self._STATE_COLORS[self.STATE_UNFEEDABLE]) +
+ " - Unfeedable"),
+ (RL(" ") +
+ RL(self.STATE_CONT,
+ self._STATE_COLORS[self.STATE_CONT]) +
" - Already continued-to; Tensor handle available from output "
"slot(s)"),
- (RL(" ") + RL("O", "yellow") +
+ (RL(" ") +
+ RL(self.STATE_DUMPED_INTERMEDIATE,
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE]) +
+ " - Unfeedable"),
+ (RL(" ") +
+ RL(self.STATE_OVERRIDDEN,
+ self._STATE_COLORS[self.STATE_OVERRIDDEN]) +
" - Has overriding (injected) tensor value"),
- (RL(" ") + RL("D", "magenta") +
+ (RL(" ") +
+ RL(self.STATE_DIRTY_VARIABLE,
+ self._STATE_COLORS[self.STATE_DIRTY_VARIABLE]) +
" - Dirty variable: Variable already updated this node stepper.")])
def cont(self, args, screen_info=None):
@@ -347,6 +398,8 @@ class NodeStepperCLI(object):
cont_result = self._node_stepper.cont(
parsed.target_name,
+ invalidate_from_updated_variables=(
+ parsed.invalidate_from_updated_variables),
restore_variable_values=parsed.restore_variable_values)
self._completed_nodes.add(parsed.target_name.split(":")[0])
@@ -363,8 +416,16 @@ class NodeStepperCLI(object):
lines.append(feed_info_line)
if feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_HANDLE:
font_attr_segs[line_counter] = [
- (len(feed_name) + 2, len(feed_info_line), "green")
+ (len(feed_name) + 2,
+ len(feed_info_line),
+ self._STATE_COLORS[self.STATE_CONT])
]
+ elif (feed_types[feed_name] ==
+ stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE):
+ font_attr_segs[line_counter] = [(
+ len(feed_name) + 2,
+ len(feed_info_line),
+ self._STATE_COLORS[self.STATE_DUMPED_INTERMEDIATE])]
elif feed_types[feed_name] == stepper.NodeStepper.FEED_TYPE_OVERRIDE:
font_attr_segs[line_counter] = [
(len(feed_name) + 2, len(feed_info_line), "yellow")
@@ -542,4 +603,3 @@ class NodeStepperCLI(object):
return [(element_name + ":%d" % slot) for slot in slots]
else:
return []
-
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 78c56697ae..0dd4493c95 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -140,310 +140,339 @@ class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
self.assertLess(node_names.index("e"), node_names.index("f"))
def testListingSortedNodesPresentsTransitveClosure(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self._assert_nodes_topologically_sorted_with_target_e(node_names)
- self.assertEqual(len(node_names), len(stat_labels))
- for stat_label in stat_labels:
- self.assertEqual(" ", stat_label)
- self.assertEqual(0, node_pointer)
+ self._assert_nodes_topologically_sorted_with_target_e(node_names)
+ self.assertEqual(len(node_names), len(stat_labels))
+ for stat_label in stat_labels:
+ self.assertEqual(" ", stat_label)
+ self.assertEqual(0, node_pointer)
def testListingSortedNodesLabelsPlaceholders(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))
+ with stepper.NodeStepper(self.sess, self.f) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self._assert_nodes_topologically_sorted_with_target_f(node_names)
+ self._assert_nodes_topologically_sorted_with_target_f(node_names)
- index_ph = node_names.index("ph")
- self.assertEqual(len(node_names), len(stat_labels))
- for i in xrange(len(stat_labels)):
- if index_ph == i:
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
- stat_labels[i])
- else:
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
- stat_labels[i])
+ index_ph = node_names.index("ph")
+ self.assertEqual(len(node_names), len(stat_labels))
+ for i in xrange(len(stat_labels)):
+ if index_ph == i:
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
+ stat_labels[i])
+ else:
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER,
+ stat_labels[i])
- self.assertEqual(0, node_pointer)
+ self.assertEqual(0, node_pointer)
def testContToNonexistentNodeShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))
+ with stepper.NodeStepper(self.sess, self.f) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["foobar"])
- self.assertEqual([
- "ERROR: foobar is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.cont(["foobar"])
+ self.assertEqual([
+ "ERROR: foobar is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testContToNodeOutsideTransitiveClosureShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["f"])
- self.assertEqual([
- "ERROR: f is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.cont(["f"])
+ self.assertEqual([
+ "ERROR: f is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testContToValidNodeShouldUpdateStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_c = node_names.index("c")
- self.assertEqual(" ", stat_labels[index_c])
- self.assertEqual(0, node_pointer)
+ index_c = node_names.index("c")
+ self.assertEqual(" ", stat_labels[index_c])
+ self.assertEqual(0, node_pointer)
- output = cli.cont("c")
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.cont("c")
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- self.assertGreaterEqual(len(node_names), 3)
- self.assertIn("c", node_names)
- index_c = node_names.index("c")
- self.assertEqual(index_c, node_pointer)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
+ self.assertGreaterEqual(len(node_names), 3)
+ self.assertIn("c", node_names)
+ index_c = node_names.index("c")
+ self.assertEqual(index_c, node_pointer)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])
- output = cli.cont("d")
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.cont("d")
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- used_feed_types = _parsed_used_feeds(output.lines)
- self.assertEqual({"c:0": "handle"}, used_feed_types)
+ used_feed_types = _parsed_used_feeds(output.lines)
+ self.assertEqual({
+ "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, used_feed_types)
- self.assertGreaterEqual(len(node_names), 3)
- self.assertIn("d", node_names)
- index_d = node_names.index("d")
- self.assertEqual(index_d, node_pointer)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
+ self.assertGreaterEqual(len(node_names), 3)
+ self.assertIn("d", node_names)
+ index_d = node_names.index("d")
+ self.assertEqual(index_d, node_pointer)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.list_sorted_nodes([])
- orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual(0, node_pointer)
-
- for i in xrange(len(orig_node_names)):
+ output = cli.list_sorted_nodes([])
+ orig_node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual(0, node_pointer)
+
+ for i in xrange(len(orig_node_names)):
+ output = cli.step([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
+
+ next_node_name = node_names[node_pointer]
+ self.assertEqual(orig_node_names[i], next_node_name)
+
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[node_pointer])
+
+ # The order in which the nodes are listed should not change as the
+ # stepping happens.
+ output = cli.list_sorted_nodes([])
+ node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual(orig_node_names, node_names)
+
+ if i < len(orig_node_names) - 1:
+ self.assertEqual(i + 1, node_pointer)
+ else:
+ # Stepped over the limit. Pointer should be at -1.
+ self.assertEqual(-1, node_pointer)
+
+ # Attempt to step once more after the end has been reached should error
+ # out.
output = cli.step([])
+ self.assertEqual([
+ "ERROR: Cannot step any further because the end of the sorted "
+ "transitive closure has been reached."
+ ], output.lines)
+
+ def testSteppingMultipleStepsUpdatesStatus(self):
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+
+ output = cli.list_sorted_nodes([])
+ orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)
+
+ output = cli.step(["-t", "3"])
node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
output.lines)
- next_node_name = node_names[node_pointer]
- self.assertEqual(orig_node_names[i], next_node_name)
+ self.assertEqual(orig_node_names[2], node_names[node_pointer])
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[node_pointer])
+ for i in xrange(node_pointer):
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
- # The order in which the nodes are listed should not change as the
- # stepping happens.
- output = cli.list_sorted_nodes([])
- node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual(orig_node_names, node_names)
+ for i in xrange(node_pointer + 1, len(stat_labels)):
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
- if i < len(orig_node_names) - 1:
- self.assertEqual(i + 1, node_pointer)
- else:
- # Stepped over the limit. Pointer should be at -1.
- self.assertEqual(-1, node_pointer)
+ def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ sorted_nodes = node_stepper.sorted_nodes()
+ closure_elements = node_stepper.closure_elements()
+
+ # Find a node which is in the list of sorted nodes, but whose output
+ # Tensor is not in the transitive closure.
+ no_output_node = None
+ for node in sorted_nodes:
+ if (node + ":0" not in closure_elements and
+ node + ":1" not in closure_elements):
+ no_output_node = node
+ break
+
+ self.assertIsNotNone(no_output_node)
+
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont([no_output_node])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- # Attempt to step once more after the end has been reached should error out.
- output = cli.step([])
- self.assertEqual([
- "ERROR: Cannot step any further because the end of the sorted "
- "transitive closure has been reached."
- ], output.lines)
+ self.assertEqual(no_output_node, node_names[node_pointer])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[node_pointer])
- def testSteppingMultipleStepsUpdatesStatus(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ def testContToUpdateNodeWithTrackingLeadsToDirtyVariableLabel(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_b/ApplyGradientDescent", "-i"])
- output = cli.list_sorted_nodes([])
- orig_node_names, _, _ = _parse_sorted_nodes_list(output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
- output = cli.step(["-t", "3"])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_b/ApplyGradientDescent"])
- self.assertEqual(orig_node_names[2], node_names[node_pointer])
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
- for i in xrange(node_pointer):
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
+ def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
+ with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+ output = cli.cont(["opt/update_a/ApplyGradientDescent",
+ "--invalidate_from_updated_variables"])
- for i in xrange(node_pointer + 1, len(stat_labels)):
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[i])
+ # After cont() call on .../update_a/..., Variable a should have been
+ # marked as dirty, whereas b should not have.
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
- def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
- node_stepper = stepper.NodeStepper(self.sess, self.opt)
-
- sorted_nodes = node_stepper.sorted_nodes()
- closure_elements = node_stepper.closure_elements()
-
- # Find a node which is in the list of sorted nodes, but whose output tensor
- # is not in the transitive closure.
- no_output_node = None
- for node in sorted_nodes:
- if (node + ":0" not in closure_elements and
- node + ":1" not in closure_elements):
- no_output_node = node
- break
-
- self.assertIsNotNone(no_output_node)
-
- cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont([no_output_node])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
-
- self.assertEqual(no_output_node, node_names[node_pointer])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[node_pointer])
-
- def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
- output = cli.cont(["opt/update_b/ApplyGradientDescent"])
-
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
+ output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
- def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
- output = cli.cont(["opt/update_a/ApplyGradientDescent"])
-
- # After cont() call on .../update_a/..., Variable a should have been marked
- # as dirty, whereas b should not have.
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
-
- output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])
-
- # After cont() call on .../update_b/... with the -r flag, Variable b should
- # have been marked as dirty, whereas Variable a should not be because it
- # should have been restored.
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("b")])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
- stat_labels[node_names.index("a")])
+ # After cont() call on .../update_b/... with the -r flag, Variable b
+ # should have been marked as dirty, whereas Variable a should not be
+ # because it should have been restored.
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, _ = _parse_sorted_nodes_list(output.lines)
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("b")])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE,
+ stat_labels[node_names.index("a")])
def testPrintTensorShouldWorkWithTensorName(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.cont("d")
- output = cli.print_tensor(["d:0"])
+ cli.cont("d")
+ output = cli.print_tensor(["d:0"])
- self.assertEqual("Tensor \"d:0\":", output.lines[0])
- self.assertEqual("-20.0", output.lines[-1])
+ self.assertEqual("Tensor \"d:0\":", output.lines[0])
+ self.assertEqual("-20.0", output.lines[-1])
def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.cont("d")
- output = cli.print_tensor(["d"])
+ cli.cont("d")
+ output = cli.print_tensor(["d"])
- self.assertEqual("Tensor \"d:0\":", output.lines[0])
- self.assertEqual("-20.0", output.lines[-1])
+ self.assertEqual("Tensor \"d:0\":", output.lines[0])
+ self.assertEqual("-20.0", output.lines[-1])
def testPrintTensorShouldWorkSlicingString(self):
ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
- cli = stepper_cli.NodeStepperCLI(
- stepper.NodeStepper(
- self.sess, self.f, feed_dict={self.ph: ph_value}))
+ with stepper.NodeStepper(
+ self.sess, self.f, feed_dict={self.ph: ph_value}) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor(["ph:0[:, 1]"])
- self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
- self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
+ output = cli.print_tensor(["ph:0[:, 1]"])
+ self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
+ self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
- output = cli.print_tensor(["ph[:, 1]"])
- self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
- self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
+ output = cli.print_tensor(["ph[:, 1]"])
+ self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
+ self.assertEqual(repr(ph_value[:, 1]), output.lines[-1])
def testPrintTensorWithNonexistentTensorShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor(["foobar"])
- self.assertEqual([
- "ERROR: foobar is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ output = cli.print_tensor(["foobar"])
+ self.assertEqual([
+ "ERROR: foobar is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
def testPrintTensorWithNoHandleShouldError(self):
- cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.print_tensor("e")
- self.assertEqual([
- "This stepper instance does not have access to the value of tensor "
- "\"e:0\""
- ], output.lines)
+ output = cli.print_tensor("e")
+ self.assertEqual([
+ "This stepper instance does not have access to the value of tensor "
+ "\"e:0\""
+ ], output.lines)
def testInjectTensorValueByTensorNameShouldBeReflected(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- output = cli.cont(["d"])
- node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
- self.assertEqual("d", node_names[node_pointer])
+ output = cli.cont(["d"])
+ node_names, _, node_pointer = _parse_sorted_nodes_list(output.lines)
+ self.assertEqual("d", node_names[node_pointer])
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_d = node_names.index("d")
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
- stat_labels[index_d])
+ index_d = node_names.index("d")
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
+ stat_labels[index_d])
- self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))
+ self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))
- output = cli.inject_value(["d:0", "20.0"])
+ output = cli.inject_value(["d:0", "20.0"])
- # Verify that the override is available.
- self.assertEqual(["d:0"], node_stepper.override_names())
+ # Verify that the override is available.
+ self.assertEqual(["d:0"], node_stepper.override_names())
- # Verify that the list of sorted nodes reflects the existence of the value
- # override (i.e., injection).
- output = cli.list_sorted_nodes([])
- node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
- output.lines)
+ # Verify that the list of sorted nodes reflects the existence of the value
+ # override (i.e., injection).
+ output = cli.list_sorted_nodes([])
+ node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
+ output.lines)
- index_d = node_names.index("d")
- self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
- stat_labels[index_d])
- self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
- stat_labels[index_d])
+ index_d = node_names.index("d")
+ self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
+ stat_labels[index_d])
+ self.assertIn(stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN,
+ stat_labels[index_d])
def testInjectTensorValueByNodeNameShouldBeReflected(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
- cli.inject_value(["d", "20.0"])
- self.assertEqual(["d:0"], node_stepper.override_names())
+ cli.inject_value(["d", "20.0"])
+ self.assertEqual(["d:0"], node_stepper.override_names())
def testInjectToNonexistentTensorShouldError(self):
- node_stepper = stepper.NodeStepper(self.sess, self.e)
- cli = stepper_cli.NodeStepperCLI(node_stepper)
-
- output = cli.inject_value(["foobar:0", "20.0"])
- self.assertEqual([
- "ERROR: foobar:0 is not in the transitive closure of this stepper "
- "instance."
- ], output.lines)
+ with stepper.NodeStepper(self.sess, self.e) as node_stepper:
+ cli = stepper_cli.NodeStepperCLI(node_stepper)
+
+ output = cli.inject_value(["foobar:0", "20.0"])
+ self.assertEqual([
+ "ERROR: foobar:0 is not in the transitive closure of this stepper "
+ "instance."
+ ], output.lines)
if __name__ == "__main__":
diff --git a/tensorflow/python/debug/stepper.py b/tensorflow/python/debug/stepper.py
index fb5e972627..3cbe83d072 100644
--- a/tensorflow/python/debug/stepper.py
+++ b/tensorflow/python/debug/stepper.py
@@ -18,11 +18,16 @@ from __future__ import division
from __future__ import print_function
import copy
+import os
+import shutil
+import tempfile
+import time
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug import debug_data
+from tensorflow.python.debug import debug_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import session_ops
@@ -64,9 +69,17 @@ class NodeStepper(object):
tree of the target. When it reaches an input where one of the following is
available, it will supply the available value to the feed_dict of the cont()
call:
- (1) TensorHandles from previous cont() calls.
- (2) Overriding (injected) values from the client.
- (3) Feeds supplied during the construction of the stepper instance.
+ (1) Overriding (injected) values from the client.
+ (2) TensorHandles from previous cont() calls.
+ (3) Dumped intermediate Tensors from previous cont() calls.
+ (4) Feeds supplied during the construction of the stepper instance.
+
+ During the cont() call, intermediate Tensors are dumped to temporary
+ directories. The dumped Tensor values will be used in subsequent cont() calls
+ when they are required as data dependencies.
+
+ The temporary directories are automatically cleaned up when the NodeStepper
+ instance exits as a context manager.
Once the tracing is complete, it will issue a run() call on the
underlying session, using the aforementioned feed_dict prepared by the input
@@ -95,10 +108,7 @@ class NodeStepper(object):
FEED_TYPE_CLIENT = "client"
FEED_TYPE_HANDLE = "handle"
FEED_TYPE_OVERRIDE = "override"
-
- # TODO(cais): The following member constant is currently unused. Use it when
- # the stepper is capable of using dumped intermediate tensors.
- FEED_TYPE_INTERMEDIATE = "intermediate"
+ FEED_TYPE_DUMPED_INTERMEDIATE = "dumped_intermediate"
def __init__(self, sess, fetches, feed_dict=None):
"""Constructor for Debugger.
@@ -125,11 +135,15 @@ class NodeStepper(object):
self._variable_initial_values = {}
# Initialize the map for output recipients (targets).
- self._non_control_output_targets = {}
+ self._output_targets = {}
# Sorted transitive closure of the fetched node.
- self._sorted_nodes, self._closure_elements = self._dfs_visit(
- self._sess.graph, self._fetch_list)
+ # We also collect the list of the names of the reference-type Tensors,
+ # because we later need to avoid using intermediate dumps for such Tensors.
+ (self._sorted_nodes,
+ self._closure_elements,
+ self._ref_tensor_names) = self._dfs_visit(self._sess.graph,
+ self._fetch_list)
self._transitive_closure_set = set(self._sorted_nodes)
@@ -146,6 +160,11 @@ class NodeStepper(object):
# tensor handles.
self._tensor_handles = {}
+ # Cached intermediate tensor values: a dict mapping tensor names to
+ # DebugTensorDatum.
+ self._dumped_intermediate_tensors = {}
+ self._dump_session_root = tempfile.mkdtemp(prefix="tfdbg_stepper_")
+
# Feed dict from the client.
self._client_feed_dict = {}
if feed_dict:
@@ -161,6 +180,13 @@ class NodeStepper(object):
# What the feed types were used by the last cont() call.
self._last_feed_types = {}
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ if os.path.isdir(self._dump_session_root):
+ shutil.rmtree(self._dump_session_root)
+
def _get_fetch_and_name_lists(self, flattened_fetches):
"""Get the lists of fetches and their names.
@@ -220,6 +246,11 @@ class NodeStepper(object):
# Graph elements in the transitive closure, including the nodes and tensors.
closure_elements = [elem.name for elem in elem_list]
+ ref_tensor_names = set()
+ for element in elem_list:
+ if isinstance(element, ops.Tensor) and element.dtype._is_ref_dtype: # pylint: disable=protected-access
+ ref_tensor_names.add(element.name)
+
while elem_stack:
curr_elem = elem_stack.pop()
curr_node = self._get_node(curr_elem)
@@ -238,22 +269,21 @@ class NodeStepper(object):
# Iterate through the (non-control) inputs.
for inp in all_inputs:
- is_non_control_input = inp in non_control_inputs
-
# Set up the non-control output map.
- if is_non_control_input:
- if inp.name not in self._non_control_output_targets:
- self._non_control_output_targets[inp.name] = set([curr_elem.name])
- else:
- self._non_control_output_targets[inp.name].add(curr_elem.name)
-
- if (inp.op.type in ["Variable", "VariableV2"] and
- inp.name not in self._variable_initializers):
- # Obtain the initializer op of the variable, in case the Variable's
- # value needs to be restored later.
- initializer = graph.as_graph_element(inp.op.name + "/Assign")
- self._variable_initializers[inp.name] = initializer
- self._variable_initial_values[inp.name] = initializer.inputs[1]
+ # if is_non_control_input:
+ if inp.name not in self._output_targets:
+ self._output_targets[inp.name] = set([curr_elem.name])
+ else:
+ self._output_targets[inp.name].add(curr_elem.name)
+
+ if (isinstance(inp, ops.Tensor) and
+ inp.op.type in ["Variable", "VariableV2"] and
+ inp.name not in self._variable_initializers):
+ # Obtain the initializer op of the variable, in case the Variable's
+ # value needs to be restored later.
+ initializer = graph.as_graph_element(inp.op.name + "/Assign")
+ self._variable_initializers[inp.name] = initializer
+ self._variable_initial_values[inp.name] = initializer.inputs[1]
inp_node = self._get_node(inp)
if inp_node.name in done:
@@ -262,6 +292,8 @@ class NodeStepper(object):
elem_stack.append(inp)
closure_elements.append(inp.name)
+ if isinstance(inp, ops.Tensor) and inp.dtype._is_ref_dtype: # pylint: disable=protected-access
+ ref_tensor_names.add(inp.name)
# Now that we have traversed the transitive closure and obtained the
# node-input map, we can topologically sort them.
@@ -291,7 +323,7 @@ class NodeStepper(object):
stack.extend(pushes)
- return sorted_nodes, closure_elements
+ return sorted_nodes, closure_elements, ref_tensor_names
def sorted_nodes(self):
"""Get a topologically-sorted list of node names of the stepper.
@@ -409,10 +441,15 @@ class NodeStepper(object):
return self._last_feed_types
+ # TODO(cais): Add method last_updated_variables() to allow client to look
+ # up the Variables updated in the last cont() call.
+
def cont(self,
target,
use_tensor_handles=True,
+ use_dumped_intermediates=True,
use_overrides=True,
+ invalidate_from_updated_variables=False,
restore_variable_values=False):
"""Continue till the completion of the specified target tensor.
@@ -423,8 +460,13 @@ class NodeStepper(object):
# TODO(cais): Support multiple fetches as in Session.run() interface.
use_tensor_handles: (bool) Whether this cont() run will use cached tensor
handles to avoid recomputation. Default: True.
+ use_dumped_intermediates: (bool) Whether this cont() call will use dumped
+ intermediate tensors to avoid recomputation. Default: True.
use_overrides: (bool) Whether the overriding tensor values supplied by
the client are to be used in this cont() call. Default: True.
+ invalidate_from_updated_variables: (bool) Whether to invalidate the
+ cached tensor handles and dumped intermediate tensors affected by the
+ Variable updates that happen in this cont() call.
restore_variable_values: (bool) Whether the old values of the variables
(before any cont() calls in this object) are to be restored.
@@ -441,9 +483,6 @@ class NodeStepper(object):
self._last_feed_types = {}
- # The feeds to be used in the Session.run() call.
- feeds = {}
-
if isinstance(target, six.string_types):
# Fetch target is a string. Assume it is the name of the Tensor or Op and
# will attempt to find it in the Session's graph.
@@ -494,6 +533,12 @@ class NodeStepper(object):
self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE
return self._tensor_handles[target_name].eval()
+ # Check if a dumped intermediate tensor can be used on the fetch directly.
+ if (use_dumped_intermediates and
+ target_name in self._dumped_intermediate_tensors):
+ self._last_feed_types[target_name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
+ return self._dumped_intermediate_tensors[target_name].get_tensor()
+
# Check if an overriding tensor value can be used directly.
if use_overrides and target_name in self._override_tensors:
# Override is available. Return the value right away.
@@ -510,6 +555,7 @@ class NodeStepper(object):
# =========================================================================
# Use a non-recursive method to trace the inputs from the node and set up
# the feeds.
+ feeds = {} # The feeds to be used in the Session.run() call.
fetched = self._sess.graph.as_graph_element(target_name)
elem_stack = [fetched]
done = set()
@@ -529,7 +575,7 @@ class NodeStepper(object):
# Determine whether the input is feedable. Reference-type tensors,
# e.g., Variables, should not be fed, because they can change.
if isinstance(inp, ops.Tensor):
- is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
+ is_inp_ref = inp.dtype._is_ref_dtype # pylint: disable=protected-access
can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
else:
is_inp_ref = False
@@ -574,11 +620,17 @@ class NodeStepper(object):
# Use client-supplied overriding tensor value.
feeds[inp] = self._override_tensors[inp.name]
self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE
- elif (use_tensor_handles and can_feed and
- inp.name in self._tensor_handles and inp not in feeds):
+ elif (can_feed and inp not in feeds and
+ use_tensor_handles and inp.name in self._tensor_handles):
# Tensor handle found in cache.
feeds[inp] = self._tensor_handles[inp.name].eval()
self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE
+ elif (can_feed and inp not in feeds and
+ use_dumped_intermediates and
+ inp.name in self._dumped_intermediate_tensors):
+ # Dumped intermediate Tensor found.
+ feeds[inp] = self._dumped_intermediate_tensors[inp.name].get_tensor()
+ self._last_feed_types[inp.name] = self.FEED_TYPE_DUMPED_INTERMEDIATE
elif inp.name in self._client_feed_dict:
# This input is available in the client feed_dict.
feeds[inp] = self._client_feed_dict[inp.name]
@@ -602,14 +654,12 @@ class NodeStepper(object):
for variable in restored_variables:
self._dirty_variables.remove(variable)
- # Prepare RunOptions for DebugTensorWatches
- run_options = config_pb2.RunOptions()
- # TODO(cais): Add fields for watching intermediate tensors.
-
+ (dump_path,
+ run_options) = self._prepare_cont_call_dump_path_and_run_options()
if isinstance(fetched, ops.Operation):
# The fetched is an Operation: Will not get tensor handle.
self._sess.run(fetched, feed_dict=feeds, options=run_options)
- # No return value for a run of an Operation
+ return_value = None
else:
# This is a Tensor: Will get tensor handle and cache it.
# Will also get the additional requested tensor handles (if any).
@@ -622,17 +672,58 @@ class NodeStepper(object):
])
handle_names.extend(additional_handle_requests)
- for handle_name, tensor in zip(handle_names, tensors_to_get_handles_for):
- handle = self._sess.run(session_ops.get_session_handle(tensor),
- feed_dict=feeds,
- options=run_options)
+ handles = self._sess.run(
+ [session_ops.get_session_handle(tensor) for tensor in
+ tensors_to_get_handles_for],
+ feed_dict=feeds,
+ options=run_options)
+ for handle_name, handle in zip(handle_names, handles):
self._tensor_handles[handle_name] = handle
- return self._tensor_handles[target_name].eval()
+ return_value = self._tensor_handles[target_name].eval()
+
+ self._load_dumped_intermediate_tensors(dump_path, target_name)
- # Invalidate caches at the end.
- for touched_variable in touched_variables:
- self._invalidate_transitively_outgoing_cache(touched_variable)
+ if invalidate_from_updated_variables:
+ # Invalidate caches at the end.
+ for touched_variable in touched_variables:
+ self._invalidate_transitively_outgoing_cache(touched_variable)
+
+ return return_value
+
+ def _prepare_cont_call_dump_path_and_run_options(self):
+ """Prepare the dump path and RunOptions for next cont() call.
+
+ Returns:
+ dump_path: (str) Directory path to which the intermediate tensors will be
+ dumped.
+ run_options: (config_pb2.RunOptions) The RunOptions containing the tensor
+ watch options for this graph.
+ """
+ run_options = config_pb2.RunOptions()
+ dump_path = self._cont_call_dump_path()
+ for element_name in self._closure_elements:
+ if ":" in element_name:
+ debug_utils.add_debug_tensor_watch(
+ run_options,
+ debug_data.get_node_name(element_name),
+ output_slot=debug_data.get_output_slot(element_name),
+ debug_urls=["file://" + dump_path])
+
+ return dump_path, run_options
+
+ def _cont_call_dump_path(self):
+ return os.path.join(self._dump_session_root,
+ "cont_%d" % int(time.time() * 1e6))
+
+ def _load_dumped_intermediate_tensors(self, dump_path, target_name):
+ dump_dir = debug_data.DebugDumpDir(dump_path, validate=False)
+ for dump in dump_dir.dumped_tensor_data:
+ if (dump.tensor_name not in self._ref_tensor_names and
+ dump.tensor_name not in self._tensor_handles and
+ dump.tensor_name not in self._override_tensors and
+ dump.tensor_name != target_name):
+ self._dumped_intermediate_tensors[dump.tensor_name] = dump
def _get_node_name(self, graph_element_name):
return graph_element_name.split(":")[0]
@@ -646,29 +737,33 @@ class NodeStepper(object):
Uses non-recursive implementation to avoid stack overflow on deep networks.
- TODO(cais): Currently, only TensorHandle caches are invalidated. Invalidate
- cached intermediate tensor values from dumps when dumps are added.
-
Args:
source_element: The source graph element (e.g., a Variable output slot)
to trace the output from.
"""
- if not self._tensor_handles:
+ if not self._tensor_handles and not self._dumped_intermediate_tensors:
return
# First, use cached invalidation paths to eliminate some cached tensor
- # handles.
- to_delete = []
+ # handles and intermediate tensors.
+ to_delete_handles = []
for handle_name in self._tensor_handles:
if (handle_name in self._cached_invalidation_path and
source_element in self._cached_invalidation_path[handle_name]):
- to_delete.append(handle_name)
-
- for handle_name in to_delete:
+ to_delete_handles.append(handle_name)
+ for handle_name in to_delete_handles:
del self._tensor_handles[handle_name]
- if not self._tensor_handles:
+ to_delete_intermediates = []
+ for intm_tensor_name in self._dumped_intermediate_tensors:
+ if (intm_tensor_name in self._cached_invalidation_path and
+ source_element in self._cached_invalidation_path[intm_tensor_name]):
+ to_delete_intermediates.append(intm_tensor_name)
+ for intermediate in to_delete_intermediates:
+ del self._dumped_intermediate_tensors[intermediate]
+
+ if not self._tensor_handles and not self._dumped_intermediate_tensors:
return
stack = [source_element]
@@ -676,19 +771,22 @@ class NodeStepper(object):
while stack:
curr_element = stack.pop()
-
done.add(curr_element)
- if curr_element in self._tensor_handles:
+ if (curr_element in self._tensor_handles or
+ curr_element in self._dumped_intermediate_tensors):
# Cache the invalidation path for potential future use.
if curr_element not in self._cached_invalidation_path:
self._cached_invalidation_path[curr_element] = set([source_element])
else:
self._cached_invalidation_path[curr_element].add(source_element)
- del self._tensor_handles[curr_element]
+ if curr_element in self._tensor_handles:
+ del self._tensor_handles[curr_element]
+ else:
+ del self._dumped_intermediate_tensors[curr_element]
- targets = self._non_control_output_targets.get(curr_element, [])
+ targets = self._output_targets.get(curr_element, [])
for target in targets:
if target in done:
continue
@@ -740,6 +838,16 @@ class NodeStepper(object):
return set([self._get_node_name(name) for name in self._tensor_handles])
+ def intermediate_tensor_names(self):
+ """Get list of the names of the Tensors for which dumps are available.
+
+ Returns:
+ (list of str) List of the names of the Tensors for which intermediate
+ dumps are available.
+ """
+
+ return list(self._dumped_intermediate_tensors.keys())
+
def dirty_variables(self):
"""Get the set of variables that are currently "dirty".
@@ -817,6 +925,8 @@ class NodeStepper(object):
return self._override_tensors[tensor_name]
elif tensor_name in self._tensor_handles:
return self._tensor_handles[tensor_name].eval()
+ elif tensor_name in self._dumped_intermediate_tensors:
+ return self._dumped_intermediate_tensors[tensor_name].get_tensor()
else:
raise ValueError(
"This stepper instance does not have access to the value of "
diff --git a/tensorflow/python/debug/stepper_test.py b/tensorflow/python/debug/stepper_test.py
index 20d79f3ade..63501b4fe6 100644
--- a/tensorflow/python/debug/stepper_test.py
+++ b/tensorflow/python/debug/stepper_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import gradient_descent
@@ -54,277 +55,310 @@ class StepperTest(test_util.TensorFlowTestCase):
self.sess = session.Session()
self.sess.run(variables.global_variables_initializer())
- self.sess = session.Session()
- self.sess.run(variables.global_variables_initializer())
-
def tearDown(self):
ops.reset_default_graph()
def testContToFetchNotInTransitiveClosureShouldError(self):
- stepper = NodeStepper(self.sess, "e:0")
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(7, len(sorted_nodes))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
- self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
-
- self.assertSetEqual(
- {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"},
- set(stepper.closure_elements()))
-
- with self.assertRaisesRegexp(
- ValueError,
- "Target \"f:0\" is not in the transitive closure for the fetch of the "
- "stepper"):
- stepper.cont("f:0")
-
- def testContToNodeNameShouldReturnTensorvalue(self):
- stepper = NodeStepper(self.sess, "e:0")
-
- cont_result = stepper.cont("c")
- self.assertAllClose(6.0, cont_result)
+ with NodeStepper(self.sess, "e:0") as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(7, len(sorted_nodes))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
+ self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
- def testUsingNamesNotUsingIntermediateTensors(self):
- stepper = NodeStepper(self.sess, "e:0")
+ self.assertSetEqual(
+ {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"},
+ set(stepper.closure_elements()))
- # The first cont() call should have used no feeds.
- result = stepper.cont("c:0")
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Target \"f:0\" is not in the transitive closure for the fetch of "
+ "the stepper"):
+ stepper.cont("f:0")
- # The second cont() call should have used the tensor handle from the
- # previous cont() call.
- result = stepper.cont("e:0")
- self.assertAllClose(24.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ def testContToNodeNameShouldReturnTensorValue(self):
+ with NodeStepper(self.sess, "e:0") as stepper:
+ self.assertAllClose(6.0, stepper.cont("c"))
- def testUsingNodesNotUsingIntermediateTensors(self):
- stepper = NodeStepper(self.sess, self.e)
+ def testUsingNamesNotUsingIntermediateTensors(self):
+ with NodeStepper(self.sess, "e:0") as stepper:
+ # The first cont() call should have used no feeds.
+ result = stepper.cont("c:0")
+ self.assertAllClose(6.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # The second cont() call should have used the tensor handle from the
+ # previous cont() call.
+ result = stepper.cont("e:0")
+ self.assertAllClose(24.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+ self.assertAllClose(4.0, stepper.get_tensor_value("d:0"))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
- # There should be no handles before any cont() calls.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
-
- # Before the cont() call, the stepper should not have access to the value
- # of c:0.
- with self.assertRaisesRegexp(
- ValueError,
- "This stepper instance does not have access to the value of tensor "
- "\"c:0\""):
- stepper.get_tensor_value("c:0")
-
- # Using the node/tensor itself, instead of the name str, should work on
- # cont().
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertEqual({"c"}, stepper.handle_node_names())
-
- # After the cont() call, the stepper should have access to the value of c:0
- # via a tensor handle.
- self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))
-
- result = stepper.cont(self.e)
- self.assertAllClose(24.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ def testUsingNodesNotUsingIntermediateTensors(self):
+ with NodeStepper(self.sess, self.e) as stepper:
+ # There should be no handles before any cont() calls.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+
+ # Before the cont() call, the stepper should not have access to the value
+ # of c:0.
+ with self.assertRaisesRegexp(
+ ValueError,
+ "This stepper instance does not have access to the value of tensor "
+ "\"c:0\""):
+ stepper.get_tensor_value("c:0")
+
+ # Using the node/tensor itself, instead of the name str, should work on
+ # cont().
+ result = stepper.cont(self.c)
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertEqual({"c"}, stepper.handle_node_names())
+
+ # After the cont() call, the stepper should have access to the value of
+ # c:0 via a tensor handle.
+ self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))
+
+ result = stepper.cont(self.e)
+ self.assertAllClose(24.0, result)
+ self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
+ stepper.intermediate_tensor_names())
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ def testContToTensorWithIntermediateDumpShouldUseDump(self):
+ with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
+ stepper.cont("c:0")
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+
+ self.assertAllClose(2.0, stepper.cont("a/read:0"))
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+
+ self.assertAllClose(10.0, stepper.cont("f:0"))
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+
+ def testDisablingUseDumpedIntermediatesWorks(self):
+ with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
+ stepper.cont("c:0")
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
+ self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
+
+ self.assertAllClose(10.0,
+ stepper.cont("f:0", use_dumped_intermediates=False))
+ self.assertEqual({}, stepper.last_feed_types())
def testIsFeedableShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.e)
-
- self.assertTrue(stepper.is_feedable("a/read:0"))
- self.assertTrue(stepper.is_feedable("b/read:0"))
- self.assertTrue(stepper.is_feedable("c:0"))
- self.assertTrue(stepper.is_feedable("d:0"))
+ with NodeStepper(self.sess, self.e) as stepper:
+ self.assertTrue(stepper.is_feedable("a/read:0"))
+ self.assertTrue(stepper.is_feedable("b/read:0"))
+ self.assertTrue(stepper.is_feedable("c:0"))
+ self.assertTrue(stepper.is_feedable("d:0"))
def testOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
-
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # There should be no overrides before any cont() calls.
- self.assertEqual([], stepper.override_names())
-
- # Calling cont() on c again should lead to use of the handle.
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
-
- # After the overriding, calling get_tensor_value() on c:0 should yield the
- # overriding value.
- self.assertEqual(7.0, stepper.get_tensor_value("c:0"))
-
- # Now c:0 should have only an override value, but no cached handle, because
- # the handle should have been invalidated.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- # Run a downstream tensor after the value override.
- result = stepper.cont(self.e)
- self.assertAllClose(28.0, result) # Should reflect the overriding value.
-
- # Should use override, instead of the handle.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # There should be no overrides before any cont() calls.
+ self.assertEqual([], stepper.override_names())
+
+ # Calling cont() on c again should lead to use of the handle.
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
+
+ # After the overriding, calling get_tensor_value() on c:0 should yield the
+ # overriding value.
+ self.assertEqual(7.0, stepper.get_tensor_value("c:0"))
+
+ # Now c:0 should have only an override value, but no cached handle,
+ # because the handle should have been invalidated.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ # Run a downstream tensor after the value override.
+ result = stepper.cont(self.e)
+ self.assertAllClose(28.0, result) # Should reflect the overriding value.
+
+ # Should use override, instead of the handle.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
def testOverrideValueTwice(self):
- stepper = NodeStepper(self.sess, self.e)
-
- # Override once.
- stepper.override_tensor("c:0", 7.0)
- self.assertAllClose(28.0, stepper.cont(self.e))
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- self.assertEqual(["e:0"], stepper.handle_names())
- self.assertSetEqual({"e"}, stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- # Calling cont(self.e) again. This time the cached tensor handle of e
- # should be used.
- self.assertEqual(28.0, stepper.cont(self.e))
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Override c again. This should have invalidated the cache for e.
- stepper.override_tensor("c:0", 8.0)
-
- self.assertEqual([], stepper.handle_names())
- self.assertEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- self.assertAllClose(32.0, stepper.cont(self.e))
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, self.e) as stepper:
+ # Override once.
+ stepper.override_tensor("c:0", 7.0)
+ self.assertAllClose(28.0, stepper.cont(self.e))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
+
+ self.assertEqual(["e:0"], stepper.handle_names())
+ self.assertSetEqual({"e"}, stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ # Calling cont(self.e) again. This time the cached tensor handle of e
+ # should be used.
+ self.assertEqual(28.0, stepper.cont(self.e))
+ self.assertEqual({
+ "e:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Override c again. This should have invalidated the cache for e.
+ stepper.override_tensor("c:0", 8.0)
+
+ self.assertEqual([], stepper.handle_names())
+ self.assertEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ self.assertAllClose(32.0, stepper.cont(self.e))
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
def testRemoveOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
-
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # The previous cont() step should have generated a cached tensor handle.
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertSetEqual({"c"}, stepper.handle_node_names())
-
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
-
- # The overriding should have invalidated the tensor handle.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
- self.assertEqual(["c:0"], stepper.override_names())
-
- result = stepper.cont(self.e)
- self.assertAllClose(28.0, result) # Should reflect the overriding value.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # The handle to tensor e:0 should have been cached, even though its
- # transitive closure contains an override.
- self.assertIn("e:0", stepper.handle_names())
- self.assertSetEqual({"e"}, stepper.handle_node_names())
-
- # Remove the override.
- stepper.remove_override("c:0")
- # c:0 should not be in the overrides anymore.
- self.assertEqual([], stepper.override_names())
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # The previous cont() step should have generated a cached tensor handle.
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertSetEqual({"c"}, stepper.handle_node_names())
+
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
+
+ # The overriding should have invalidated the tensor handle.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+ self.assertEqual(["c:0"], stepper.override_names())
+
+ result = stepper.cont(self.e)
+ self.assertAllClose(28.0, result) # Should reflect the overriding value.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ # The handle to tensor e:0 should have been cached, even though its
+ # transitive closure contains an override.
+ self.assertIn("e:0", stepper.handle_names())
+ self.assertSetEqual({"e"}, stepper.handle_node_names())
+
+ # Remove the override.
+ stepper.remove_override("c:0")
+ # c:0 should not be in the overrides anymore.
+ self.assertEqual([], stepper.override_names())
- # Removing the override should have invalidated the tensor handle for c.
- self.assertNotIn("e:0", stepper.handle_names())
- self.assertNotIn("e", stepper.handle_node_names())
+        # Removing the override should have invalidated the tensor handle for
+        # e:0, which lies downstream of c:0.
+ self.assertNotIn("e:0", stepper.handle_names())
+ self.assertNotIn("e", stepper.handle_node_names())
- # Should reflect the non-overriding value.
- self.assertAllClose(24.0, stepper.cont(self.e))
+ # Should reflect the non-overriding value.
+ self.assertAllClose(24.0, stepper.cont(self.e))
- # This time, the handle to tensor e:0 should have been cached again, even
- # thought its transitive closure contains an override.
- self.assertIn("e:0", stepper.handle_names())
- self.assertIn("e", stepper.handle_node_names())
+        # This time, the handle to tensor e:0 should have been cached again, even
+        # though its transitive closure previously contained an override.
+ self.assertIn("e:0", stepper.handle_names())
+ self.assertIn("e", stepper.handle_node_names())
- # Calling cont(self.e) again should have used the tensor handle to e:0.
- self.assertAllClose(24.0, stepper.cont(self.e))
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ # Calling cont(self.e) again should have used the tensor handle to e:0.
+ self.assertAllClose(24.0, stepper.cont(self.e))
+ self.assertEqual({
+ "e:0": NodeStepper.FEED_TYPE_HANDLE,
+ }, stepper.last_feed_types())
def testOverrideAndContToSameTensor(self):
- stepper = NodeStepper(self.sess, self.e)
+ with NodeStepper(self.sess, self.e) as stepper:
+ result = stepper.cont(self.c)
+ self.assertAllClose(6.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertEqual(["c:0"], stepper.handle_names())
+ self.assertSetEqual({"c"}, stepper.handle_node_names())
- result = stepper.cont(self.c)
- self.assertAllClose(6.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(["c:0"], stepper.handle_names())
- self.assertSetEqual({"c"}, stepper.handle_node_names())
+ self.assertAllClose(6.0, stepper.cont(self.c))
- self.assertAllClose(6.0, stepper.cont(self.c))
+ # The last cont() call should use the tensor handle directly.
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
- # The last cont() call should use the tensor handle directly.
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ # Override c:0.
+ stepper.override_tensor("c:0", 7.0)
- # Override c:0.
- stepper.override_tensor("c:0", 7.0)
+ # As a result of the override, the tensor handle should have been
+ # invalidated.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
- # As a result of the override, the tensor handle should have been
- # invalidated.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
+ result = stepper.cont(self.c)
+ self.assertAllClose(7.0, result)
- result = stepper.cont(self.c)
- self.assertAllClose(7.0, result)
-
- self.assertEqual({
- "c:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
+ self.assertEqual({
+ "c:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
def testFinalizeWithPreviousOverrides(self):
- stepper = NodeStepper(self.sess, self.e)
+ with NodeStepper(self.sess, self.e) as stepper:
+ stepper.override_tensor("a/read:0", 20.0)
+ self.assertEqual(["a/read:0"], stepper.override_names())
- stepper.override_tensor("a/read:0", 20.0)
- self.assertEqual(["a/read:0"], stepper.override_names())
+ # Should reflect the overriding value.
+ self.assertAllClose(24000.0, stepper.cont("e:0"))
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
- # Should reflect the overriding value.
- self.assertAllClose(24000.0, stepper.cont("e:0"))
- self.assertEqual({
- "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # Finalize call should have ignored the overriding value.
- self.assertAllClose(24.0, stepper.finalize())
+ # Finalize call should have ignored the overriding value.
+ self.assertAllClose(24.0, stepper.finalize())
def testRemoveNonexistentOverrideValue(self):
- stepper = NodeStepper(self.sess, self.e)
- self.assertEqual([], stepper.override_names())
-
- with self.assertRaisesRegexp(
- ValueError, "No overriding value exists for tensor \"c:0\""):
- stepper.remove_override("c:0")
+ with NodeStepper(self.sess, self.e) as stepper:
+ self.assertEqual([], stepper.override_names())
+ with self.assertRaisesRegexp(
+ ValueError, "No overriding value exists for tensor \"c:0\""):
+ stepper.remove_override("c:0")
def testAttemptToOverrideInvalidTensor(self):
stepper = NodeStepper(self.sess, self.e)
@@ -333,20 +367,18 @@ class StepperTest(test_util.TensorFlowTestCase):
stepper.override_tensor("f:0", 42.0)
def testInvalidOverrideArgumentType(self):
- stepper = NodeStepper(self.sess, self.e)
-
- with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
- stepper.override_tensor(self.a, 42.0)
+ with NodeStepper(self.sess, self.e) as stepper:
+ with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
+ stepper.override_tensor(self.a, 42.0)
def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self):
- stepper = NodeStepper(self.sess, "z:0")
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(4, len(sorted_nodes))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
- self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
+ with NodeStepper(self.sess, "z:0") as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(4, len(sorted_nodes))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
+ self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
for i in range(6):
@@ -363,45 +395,44 @@ class StepperTest(test_util.TensorFlowTestCase):
elif i == 5:
fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}}
- stepper = NodeStepper(self.sess, fetches)
-
- sorted_nodes = stepper.sorted_nodes()
- self.assertEqual(13, len(sorted_nodes))
-
- # Check the topological order of the sorted nodes.
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
- self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
- self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
-
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
- self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
- self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
- self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f"))
- self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f"))
-
- closure_elements = stepper.closure_elements()
- self.assertIn("x/read:0", closure_elements)
- self.assertIn("e:0", closure_elements)
- self.assertIn("f:0", closure_elements)
-
- self.assertEqual([0], stepper.output_slots_in_closure("x/read"))
- self.assertEqual([0], stepper.output_slots_in_closure("e"))
- self.assertEqual([0], stepper.output_slots_in_closure("f"))
-
- result = stepper.finalize()
- if i == 0 or i == 1 or i == 3 or i == 4:
- self.assertAllClose(24.0, result[0])
- self.assertAllClose(10.0, result[1][0])
- self.assertAllClose(-4.0, result[1][1])
- elif i == 2 or i == 5:
- self.assertAllClose(24.0, result["e"])
- self.assertAllClose(10.0, result["fz"]["f"])
- self.assertAllClose(-4.0, result["fz"]["z"])
+ with NodeStepper(self.sess, fetches) as stepper:
+ sorted_nodes = stepper.sorted_nodes()
+ self.assertEqual(13, len(sorted_nodes))
+
+ # Check the topological order of the sorted nodes.
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
+ self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
+ self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
+
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
+ self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
+ self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
+ self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f"))
+ self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f"))
+
+ closure_elements = stepper.closure_elements()
+ self.assertIn("x/read:0", closure_elements)
+ self.assertIn("e:0", closure_elements)
+ self.assertIn("f:0", closure_elements)
+
+ self.assertEqual([0], stepper.output_slots_in_closure("x/read"))
+ self.assertEqual([0], stepper.output_slots_in_closure("e"))
+ self.assertEqual([0], stepper.output_slots_in_closure("f"))
+
+ result = stepper.finalize()
+ if i == 0 or i == 1 or i == 3 or i == 4:
+ self.assertAllClose(24.0, result[0])
+ self.assertAllClose(10.0, result[1][0])
+ self.assertAllClose(-4.0, result[1][1])
+ elif i == 2 or i == 5:
+ self.assertAllClose(24.0, result["e"])
+ self.assertAllClose(10.0, result["fz"]["f"])
+ self.assertAllClose(-4.0, result["fz"]["z"])
class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
@@ -419,128 +450,222 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
ops.reset_default_graph()
def testGetTensorValueWorksOnPlaceholder(self):
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
self.ph1: [[-1.0], [0.5]]
- })
-
- self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
- stepper.get_tensor_value("ph0"))
- self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
- stepper.get_tensor_value("ph0:0"))
- with self.assertRaisesRegexp(
- KeyError, r"The name 'ph0:1' refers to a Tensor which does not exist"):
- stepper.get_tensor_value("ph0:1")
+ }) as stepper:
+ self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
+ stepper.get_tensor_value("ph0"))
+ self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
+ stepper.get_tensor_value("ph0:0"))
+ with self.assertRaisesRegexp(
+ KeyError,
+ r"The name 'ph0:1' refers to a Tensor which does not exist"):
+ stepper.get_tensor_value("ph0:1")
def testIsPlaceholdersShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.y)
-
- self.assertTrue(stepper.is_placeholder(self.ph0.name))
- self.assertTrue(stepper.is_placeholder(self.ph1.name))
+ with NodeStepper(self.sess, self.y) as stepper:
+ self.assertTrue(stepper.is_placeholder(self.ph0.name))
+ self.assertTrue(stepper.is_placeholder(self.ph1.name))
- self.assertFalse(stepper.is_placeholder(self.x.name))
- self.assertFalse(stepper.is_placeholder(self.y.name))
+ self.assertFalse(stepper.is_placeholder(self.x.name))
+ self.assertFalse(stepper.is_placeholder(self.y.name))
- with self.assertRaisesRegexp(ValueError,
- "A is not in the transitive closure"):
- self.assertFalse(stepper.is_placeholder("A"))
+ with self.assertRaisesRegexp(ValueError,
+ "A is not in the transitive closure"):
+ self.assertFalse(stepper.is_placeholder("A"))
def testPlaceholdersShouldGiveCorrectAnswers(self):
- stepper = NodeStepper(self.sess, self.y)
-
- self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
+ with NodeStepper(self.sess, self.y) as stepper:
+ self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
def testContWithPlaceholders(self):
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
self.ph1: [[-1.0], [0.5]]
- })
-
- self.assertEqual(4, len(stepper.sorted_nodes()))
- self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"},
- set(stepper.closure_elements()))
-
- result = stepper.cont(self.x)
- self.assertAllClose([[0.0], [5.5]], result)
- self.assertEqual({
- "ph0:0": NodeStepper.FEED_TYPE_CLIENT,
- "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
- }, stepper.last_feed_types())
-
- self.assertEqual(["x:0"], stepper.handle_names())
- self.assertSetEqual({"x"}, stepper.handle_node_names())
-
- result = stepper.cont(self.y)
- self.assertAllClose([[-1.0], [6.0]], result)
- self.assertEqual({
- "x:0": NodeStepper.FEED_TYPE_HANDLE,
- "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
- }, stepper.last_feed_types())
+ }) as stepper:
+ self.assertEqual(4, len(stepper.sorted_nodes()))
+ self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"},
+ set(stepper.closure_elements()))
+
+ result = stepper.cont(self.x)
+ self.assertAllClose([[0.0], [5.5]], result)
+ self.assertEqual({
+ "ph0:0": NodeStepper.FEED_TYPE_CLIENT,
+ "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
+ }, stepper.last_feed_types())
+
+ self.assertEqual(["x:0"], stepper.handle_names())
+ self.assertSetEqual({"x"}, stepper.handle_node_names())
+
+ result = stepper.cont(self.y)
+ self.assertAllClose([[-1.0], [6.0]], result)
+ self.assertEqual({
+ "x:0": NodeStepper.FEED_TYPE_HANDLE,
+ "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
+ }, stepper.last_feed_types())
def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self):
"""Continuing to a placeholder should be allowed, using client feed."""
ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
ph1_feed = [[-1.0], [0.5]]
- stepper = NodeStepper(
+ with NodeStepper(
self.sess, self.y, feed_dict={
self.ph0: ph0_feed,
self.ph1: ph1_feed,
- })
-
- self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ }) as stepper:
+ self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
- self.assertEqual({
- self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
+ self.assertEqual({
+ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- ph0_node = self.sess.graph.as_graph_element("ph0")
- self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ ph0_node = self.sess.graph.as_graph_element("ph0")
+ self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
+ self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
ph1_feed = [[-1.0], [0.5]]
- stepper = NodeStepper(
+ with NodeStepper(
self.sess,
self.y,
feed_dict={
self.ph0.name: ph0_feed,
self.ph1.name: ph1_feed,
- })
+ }) as stepper:
+ self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
+ self.assertEqual({
+ self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
- self.assertEqual({
- self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ ph0_node = self.sess.graph.as_graph_element("ph0")
+ self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
+ self.assertEqual({
+ self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
+ }, stepper.last_feed_types())
- ph0_node = self.sess.graph.as_graph_element("ph0")
- self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
- self.assertEqual({
- self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
- }, stepper.last_feed_types())
+ self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
- self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
+
+class StepperAssignAddTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.v = variables.Variable(10.0, name="v")
+ self.p = math_ops.add(self.v, self.v, name="p")
+ self.q = math_ops.multiply(self.p, self.p, name="q")
+ self.delta = constant_op.constant(2.0, name="delta")
+ self.v_add = state_ops.assign_add(self.v, self.delta, name="v_add")
+ self.v_add_plus_one = math_ops.add(self.v_add,
+ 1.0,
+ name="v_add_plus_one")
+
+ self.sess = session.Session()
+ self.sess.run(self.v.initializer)
+
+ def tearDown(self):
+ ops.reset_default_graph()
+
+  def testContToUpdateInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
+ self.assertAllClose(400.0, stepper.cont("q:0"))
+ self.assertItemsEqual(["v/read:0", "p:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
+ self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
+
+ self.assertAllClose(
+ 12.0, stepper.cont(
+ self.v_add, invalidate_from_updated_variables=True))
+ self.assertAllClose(12.0, self.sess.run(self.v))
+ self.assertItemsEqual(["v:0"], stepper.dirty_variables())
+ # Updating the value of v by calling v_add should have invalidated the
+ # dumped intermediate tensors for v/read:0 and p:0.
+ self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names())
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"This stepper instance does not have access to the value of tensor "
+ r"\"p:0\""):
+ stepper.get_tensor_value("p:0")
+
+ # The next cont to q should not have used any dumped intermediate tensors
+ # and its result should reflect the updated value.
+ self.assertAllClose(576.0, stepper.cont("q:0"))
+ self.assertEqual({}, stepper.last_feed_types())
+
+ def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, self.q) as stepper:
+ self.assertAllClose(400.0, stepper.cont("q:0"))
+ self.assertItemsEqual(["v/read:0", "p:0"],
+ stepper.intermediate_tensor_names())
+ self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
+ self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
+
+ stepper.override_tensor("v/read:0", 11.0)
+ self.assertItemsEqual(["v/read:0"], stepper.override_names())
+ # Overriding the upstream v/read:0 should have invalidated the dumped
+ # intermediate tensor for the downstream p:0.
+ self.assertItemsEqual([], stepper.intermediate_tensor_names())
+
+ # The next cont to q should not have used any dumped intermediate tensors
+ # and its result should reflect the overriding value.
+ self.assertAllClose(484.0, stepper.cont("q:0"))
+ self.assertEqual({
+ "v/read:0": NodeStepper.FEED_TYPE_OVERRIDE
+ }, stepper.last_feed_types())
+
+ def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self):
+ with NodeStepper(self.sess, self.q) as stepper:
+ stepper.override_tensor("v/read:0", 9.0)
+ self.assertItemsEqual(["v/read:0"], stepper.override_names())
+
+ self.assertAllClose(324.0, stepper.cont(self.q))
+ self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names())
+
+ stepper.remove_override("v/read:0")
+ self.assertItemsEqual([], stepper.override_names())
+ # Removing the pre-existing override to v/read:0 should have invalidated
+ # the dumped intermediate tensor.
+ self.assertItemsEqual([], stepper.intermediate_tensor_names())
+
+ def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self):
+ with NodeStepper(self.sess, self.v_add) as stepper:
+ stepper.cont(self.v_add)
+ self.assertAllClose(12.0, stepper.cont(self.v))
+ stepper.cont(self.v_add)
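+      # The repeated cont() on v_add should be served from the cached handle of
+      # v_add:0 and therefore must not increment the variable a second time.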
+ self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE},
+ stepper.last_feed_types())
+ self.assertAllClose(12.0, stepper.cont(self.v))
+
+ def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self):
+ with NodeStepper(self.sess, self.v_add_plus_one) as stepper:
+ stepper.cont(self.v_add_plus_one)
+ self.assertAllClose(12.0, stepper.cont(self.v))
+ stepper.cont(self.v_add_plus_one)
+ self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE},
+ stepper.last_feed_types())
+ self.assertAllClose(12.0, stepper.cont(self.v))
class StepperBackwardRunTest(test_util.TensorFlowTestCase):
@@ -580,93 +705,136 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
ops.reset_default_graph()
def testContToUpdateA(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("a:0")
- self.assertAllClose(1.0, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- result = stepper.cont("optim/learning_rate:0")
- self.assertAllClose(0.01, result)
- self.assertEqual({}, stepper.last_feed_types())
-
- # Before any cont calls on ApplyGradientDescent, there should be no "dirty"
- # variables.
- self.assertEqual(set(), stepper.dirty_variables())
-
- # First, all the two control inputs to optim.
- result = stepper.cont("optim/update_a/ApplyGradientDescent")
-
- # Now variable a should have been marked as dirty due to the update
- # by optim/update_a/ApplyGradientDescent.
- self.assertEqual({"a:0"}, stepper.dirty_variables())
- self.assertIsNone(result)
- self.assertEqual({
- "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # Check that Variable "a" has been updated properly, but "b", "c" and "d"
- # remain the same.
- # For backprop on Variable a:
- # Because f = a * b * b * c, df / da = b * b * c.
- # 1.0 - learning_rate * b * b * c
- # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(2.0, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("a:0")
+ self.assertAllClose(1.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ result = stepper.cont("optim/learning_rate:0")
+ self.assertAllClose(0.01, result)
+ self.assertEqual({}, stepper.last_feed_types())
+
+ # Before any cont calls on ApplyGradientDescent, there should be no
+ # "dirty" variables.
+ self.assertEqual(set(), stepper.dirty_variables())
+
+      # First, continue to one of the two control inputs to optim: update_a.
+ result = stepper.cont("optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True)
+
+ # Now variable a should have been marked as dirty due to the update
+ # by optim/update_a/ApplyGradientDescent.
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+ self.assertIsNone(result)
+ self.assertEqual({
+ "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
+
+ # Check that Variable "a" has been updated properly, but "b", "c" and "d"
+ # remain the same.
+ # For backprop on Variable a:
+ # Because f = a * b * b * c, df / da = b * b * c.
+ # 1.0 - learning_rate * b * b * c
+ # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84.
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testContToUpdateB(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("optim/update_b/ApplyGradientDescent")
- self.assertIsNone(result)
- self.assertEqual(set(["b:0"]), stepper.dirty_variables())
-
- # For backprop on Variable b:
- # Because f = a * b * b * c, df / da = 2 * a * b * c.
- # 2.0 - learning_rate * 2 * a * b * c
- # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
- self.assertAllClose(1.0, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True)
+ self.assertIsNone(result)
+ self.assertEqual(set(["b:0"]), stepper.dirty_variables())
+
+ # For backprop on Variable b:
+      # Because f = a * b * b * c, df / db = 2 * a * b * c.
+ # 2.0 - learning_rate * 2 * a * b * c
+ # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
+ self.assertAllClose(1.0, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testContAfterUpdateWithoutRestoringVariableValue(self):
- stepper = NodeStepper(self.sess, "optim")
-
- # First, update Variable a from 1.0 to 0.84.
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual(set(["a:0"]), stepper.dirty_variables())
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(2.0, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
-
- # Second, update Variable b without the default restore_variable_values.
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=False)
- self.assertIsNone(result)
- # For the backprop on Variable b under the updated value of a:
- # 2.0 - learning_rate * 2 * a' * b * c
- # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(1.8656, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ # First, update Variable a from 1.0 to 0.84.
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ self.assertEqual(set(["a:0"]), stepper.dirty_variables())
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+ # Tracking of the updated variables should have invalidated all
+      # intermediate tensors downstream of a:0.
+ self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("d:0", stepper.intermediate_tensor_names())
+
+ # Second, update Variable b without the default restore_variable_values.
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent", restore_variable_values=False)
+ self.assertIsNone(result)
+ # For the backprop on Variable b under the updated value of a:
+ # 2.0 - learning_rate * 2 * a' * b * c
+ # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.8656, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+
+ def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self):
+ with NodeStepper(self.sess, "optim") as stepper:
+ self.assertIsNone(stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=False))
+ # Even though invalidate_from_updated_variables is set to False, dirty
+ # variables should still have been tracked.
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+ self.assertIn("a/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("b/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("c/read:0", stepper.intermediate_tensor_names())
+ self.assertIn("d:0", stepper.intermediate_tensor_names())
+ self.assertIn("e:0", stepper.intermediate_tensor_names())
+ self.assertIn("optim/learning_rate:0",
+ stepper.intermediate_tensor_names())
+ self.assertNotIn("a:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("b:0", stepper.intermediate_tensor_names())
+ self.assertNotIn("c:0", stepper.intermediate_tensor_names())
+
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(2.0, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
+
+ # For the backprop on Variable b, the result should reflect the original
+ # value of Variable a, even though Variable a has actually been updated.
+ # 2.0 - learning_rate * 2 * a * b * c
+ # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
+ self.assertIsNone(stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=False,
+ restore_variable_values=False))
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testUpdateTwiceRestoreVariable(self):
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual({"a:0"}, stepper.dirty_variables())
-
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- # Variables a and c should have been restored and hence no longer dirty.
- # Variable b should have been marked as dirty.
- self.assertEqual({"b:0"}, stepper.dirty_variables())
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ # Variables a and c should have been restored and hence no longer dirty.
+ # Variable b should have been marked as dirty.
+ self.assertEqual({"b:0"}, stepper.dirty_variables())
    # The result of the update should be identical to what it would be if only
    # update_b were run.
@@ -680,179 +848,203 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
"clean" means no Variables have been updated by preceding cont() calls.
"""
- stepper = NodeStepper(self.sess, "optim")
-
- # First, call cont() on the two tensors on the intermediate level: e and f.
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
-
- # The cont call above should have restored Variable "b".
- result = stepper.cont("e:0")
- self.assertAllClose(8.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
-
- # Now run update_a, so as to let Variable a be diry.
- result = stepper.cont(
- "optim/update_a/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
- self.assertEqual({"a:0"}, stepper.dirty_variables())
-
- # Now, run update_b.
- result = stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
-
- # The last cont() run should have use the handle of tensor e, but not the
- # handle of tensor d, because the transitive closure of e is clean, whereas
- # that of d is dirty due to the update to a in the previous cont() call.
- self.assertEqual({
- "e:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # The result of the update_b should be identical to as if no other
- # update_* cont() calls have occurred before.
- self.assertAllClose(1.0, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(4.0, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+      # First, call cont() on the two tensors on the intermediate level: d and
+      # e.
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertItemsEqual(["a/read:0", "b/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertItemsEqual(["d:0"], stepper.handle_names())
+ self.assertEqual(set(), stepper.dirty_variables())
+
+ result = stepper.cont("e:0")
+ self.assertAllClose(8.0, result)
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
+ }, stepper.last_feed_types())
+ self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names())
+ self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"],
+ stepper.intermediate_tensor_names())
+ self.assertEqual(set(), stepper.dirty_variables())
+
+ # Now run update_a, so as to let Variable a be dirty.
+ result = stepper.cont(
+ "optim/update_a/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+ # Due to the update to the value of a:0, the dumped intermediate a/read:0
+ # should have been invalidated.
+ self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
+ # ["b/read:0", "c/read:0"],
+ self.assertEqual({"a:0"}, stepper.dirty_variables())
+
+ # Now, run update_b.
+ result = stepper.cont(
+ "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
+ self.assertIsNone(result)
+
+      # The last cont() run should have used the handle of tensor e, but not the
+ # handle of tensor d, because the transitive closure of e is clean,
+ # whereas that of d is dirty due to the update to a in the previous cont()
+ # call.
+ last_feed_types = stepper.last_feed_types()
+ self.assertNotIn("d:0", last_feed_types)
+ self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ last_feed_types["b/read:0"])
+ self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ last_feed_types["c/read:0"])
+
+ # The result of the update_b should be identical to as if no other
+ # update_* cont() calls have occurred before.
+ self.assertAllClose(1.0, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(4.0, self.sess.run(self.c))
def testRestoreVariableValues(self):
"""Test restore_variable_values() restores the old values of variables."""
- stepper = NodeStepper(self.sess, "optim")
-
- stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
- self.assertAllClose(1.84, self.sess.run(self.b))
+ with NodeStepper(self.sess, "optim") as stepper:
+ stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertAllClose(1.84, self.sess.run(self.b))
- stepper.restore_variable_values()
- self.assertAllClose(2.0, self.sess.run(self.b))
+ stepper.restore_variable_values()
+ self.assertAllClose(2.0, self.sess.run(self.b))
def testFinalize(self):
"""Test finalize() to restore variables and run the original fetch."""
- stepper = NodeStepper(self.sess, "optim")
-
- # Invoke update_b before calling finalize.
- stepper.cont(
- "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
+ with NodeStepper(self.sess, "optim") as stepper:
+ # Invoke update_b before calling finalize.
+ stepper.cont(
+ "optim/update_b/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
- result = stepper.finalize()
- self.assertIsNone(result)
+ result = stepper.finalize()
+ self.assertIsNone(result)
- # The results of the Variable updates should be the same as if no cont()
- # call has occurred on update_b.
- self.assertAllClose(0.84, self.sess.run(self.a))
- self.assertAllClose(1.84, self.sess.run(self.b))
- self.assertAllClose(3.96, self.sess.run(self.c))
+ # The results of the Variable updates should be the same as if no cont()
+ # call has occurred on update_b.
+ self.assertAllClose(0.84, self.sess.run(self.a))
+ self.assertAllClose(1.84, self.sess.run(self.b))
+ self.assertAllClose(3.96, self.sess.run(self.c))
- def testOverrideThenContToUpdate(self):
+ def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self):
"""Test cont() to update nodes after overriding tensor values."""
- stepper = NodeStepper(self.sess, "optim")
-
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual({}, stepper.last_feed_types())
- self.assertEqual(set(), stepper.dirty_variables())
- self.assertEqual(["d:0"], stepper.handle_names())
- self.assertSetEqual({"d"}, stepper.handle_node_names())
-
- # Override the value from 1.0 to 10.0.
- stepper.override_tensor("a/read:0", 10.0)
-
- self.assertEqual(["a/read:0"], stepper.override_names())
-
- result = stepper.cont(
- "optim/update_c/ApplyGradientDescent", restore_variable_values=True)
- self.assertIsNone(result)
-
- # The last cont() call should have not used the tensor handle to d:0,
- # because the transitive closure of d:0 contains an override tensor.
- self.assertEqual({
- "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
- }, stepper.last_feed_types())
-
- # The tensor handle to d:0 should have been removed due to the dirty
- # transitive closure.
- self.assertEqual([], stepper.handle_names())
- self.assertSetEqual(set(), stepper.handle_node_names())
-
- # For this backprop on c, the overriding value of a/read:0 should have been
- # used:
- # 4.0 - learning_rate * a * b * b
- # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
- self.assertAllClose(3.6, self.sess.run(self.c))
-
- # Now remove the overriding value of a/read:0.
- stepper.remove_override("a/read:0")
- self.assertEqual([], stepper.override_names())
-
- # Obtain the tensor handle to d:0 again.
- result = stepper.cont("d:0")
- self.assertAllClose(2.0, result)
- self.assertEqual(["d:0"], stepper.handle_names())
- self.assertSetEqual({"d"}, stepper.handle_node_names())
-
- # Then call update_c again, without restoring c.
- result = stepper.cont(
- "optim/update_c/ApplyGradientDescent", restore_variable_values=False)
- self.assertIsNone(result)
-
- # This time, the d:0 tensor handle should have been used, because its
- # transitive closure is clean.
- self.assertEqual({
- "d:0": NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
-
- # For this backprop on c, the overriding value of a/read:0 should have been
- # used:
- # 3.6 - learning_rate * a * b * b
- # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
- self.assertAllClose(3.56, self.sess.run(self.c))
+ with NodeStepper(self.sess, "optim") as stepper:
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual({}, stepper.last_feed_types())
+ self.assertEqual(set(), stepper.dirty_variables())
+ self.assertEqual(["d:0"], stepper.handle_names())
+ self.assertSetEqual({"d"}, stepper.handle_node_names())
+
+ # Override the value from 1.0 to 10.0.
+ stepper.override_tensor("a/read:0", 10.0)
+
+ self.assertEqual(["a/read:0"], stepper.override_names())
+
+ result = stepper.cont(
+ "optim/update_c/ApplyGradientDescent",
+ invalidate_from_updated_variables=True,
+ restore_variable_values=True)
+ self.assertIsNone(result)
+
+      # The last cont() call should not have used the tensor handle to d:0,
+ # because the transitive closure of d:0 contains an override tensor.
+ self.assertEqual({
+ "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE,
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+ # The tensor handle to d:0 should have been removed due to the dirty
+ # transitive closure.
+ self.assertEqual([], stepper.handle_names())
+ self.assertSetEqual(set(), stepper.handle_node_names())
+
+ # For this backprop on c, the overriding value of a/read:0 should have
+ # been used:
+ # 4.0 - learning_rate * a * b * b
+ # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
+ self.assertAllClose(3.6, self.sess.run(self.c))
+
+ # Now remove the overriding value of a/read:0.
+ stepper.remove_override("a/read:0")
+ self.assertEqual([], stepper.override_names())
+
+ # Obtain the tensor handle to d:0 again.
+ result = stepper.cont("d:0")
+ self.assertAllClose(2.0, result)
+ self.assertEqual(["d:0"], stepper.handle_names())
+ self.assertSetEqual({"d"}, stepper.handle_node_names())
+ self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+ # Then call update_c again, without restoring c.
+ result = stepper.cont("optim/update_c/ApplyGradientDescent",
+ restore_variable_values=False)
+ self.assertIsNone(result)
+ self.assertNotIn("a/read:0", stepper.last_feed_types())
+
+ # This time, the d:0 tensor handle should have been used, because its
+ # transitive closure is clean.
+ self.assertEqual({
+ "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ "d:0": NodeStepper.FEED_TYPE_HANDLE,
+ "optim/learning_rate:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
+ }, stepper.last_feed_types())
+
+      # For this backprop on c, the original (non-overridden) value of
+      # a/read:0 (1.0) should have been used:
+ # 3.6 - learning_rate * a * b * b
+ # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
+ self.assertAllClose(3.56, self.sess.run(self.c))
def testContToNodeWithOutputTensors(self):
"""cont() to an op should cache its output tensors if appropriate."""
- stepper = NodeStepper(self.sess, "optim")
-
- # In the transitive closure of the stepper, look for an op of which the
- # output tensor also is in the transitive closure.
- # Do not assume a specific op, e.g., ""gradients/e_grad/Reshape_1",
- # because it may vary between builds.
- closure_elements = stepper.closure_elements()
- op_with_output_in_closure = None
- for element_name in closure_elements:
- if element_name + ":0" in closure_elements:
- op_with_output_in_closure = str(element_name)
- break
-
- self.assertEqual([0],
- stepper.output_slots_in_closure(op_with_output_in_closure))
-
- self.assertIsNotNone(op_with_output_in_closure)
- output_tensor = op_with_output_in_closure + ":0"
-
- # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
- # stepper, because it is the control input to another o. However, its
- # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
- # closure, because it is the (non-control) input of certain ops. Calling
- # cont() on the op should lead to the caching of the tensor handle for
- # the output tensor.
- stepper.cont(op_with_output_in_closure)
-
- self.assertEqual([output_tensor], stepper.handle_names())
- self.assertSetEqual({op_with_output_in_closure},
- stepper.handle_node_names())
-
- # Do a cont() call that uses the cached tensor of
- # "gradients/?_grad/Reshape_1:0".
- stepper.cont(output_tensor)
- self.assertEqual({
- output_tensor: NodeStepper.FEED_TYPE_HANDLE
- }, stepper.last_feed_types())
+ with NodeStepper(self.sess, "optim") as stepper:
+ # In the transitive closure of the stepper, look for an op of which the
+ # output tensor also is in the transitive closure.
+      # Do not assume a specific op, e.g., "gradients/e_grad/Reshape_1",
+ # because it may vary between builds.
+ closure_elements = stepper.closure_elements()
+ op_with_output_in_closure = None
+ for element_name in closure_elements:
+ if element_name + ":0" in closure_elements:
+ op_with_output_in_closure = str(element_name)
+ break
+
+      self.assertIsNotNone(op_with_output_in_closure)
+      self.assertEqual(
+          [0], stepper.output_slots_in_closure(op_with_output_in_closure))
+
+      output_tensor = op_with_output_in_closure + ":0"
+
+ # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
+      # stepper, because it is the control input to another op. However, its
+ # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
+ # closure, because it is the (non-control) input of certain ops. Calling
+ # cont() on the op should lead to the caching of the tensor handle for
+ # the output tensor.
+ stepper.cont(op_with_output_in_closure)
+
+ self.assertEqual([output_tensor], stepper.handle_names())
+ self.assertSetEqual({op_with_output_in_closure},
+ stepper.handle_node_names())
+
+ # Do a cont() call that uses the cached tensor of
+ # "gradients/?_grad/Reshape_1:0".
+ stepper.cont(output_tensor)
+ self.assertEqual({
+ output_tensor: NodeStepper.FEED_TYPE_HANDLE
+ }, stepper.last_feed_types())
if __name__ == "__main__":
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 145008e902..00ab7277cb 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -427,9 +427,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
- retvals = self.invoke_node_stepper(
- stepper.NodeStepper(self._sess, fetches, feed_dict),
- restore_variable_values_on_exit=True)
+ with stepper.NodeStepper(
+ self._sess, fetches, feed_dict) as node_stepper:
+ retvals = self.invoke_node_stepper(
+ node_stepper, restore_variable_values_on_exit=True)
# Invoke run() method of the wrapped session.
retvals = self._sess.run(
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 30f0e117e6..58a8f7f9c9 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -107,10 +107,13 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
run_context.session.graph._finalized = False
# pylint: enable=protected-access
- self.invoke_node_stepper(
- stepper.NodeStepper(run_context.session, run_context.original_args.
- fetches, run_context.original_args.feed_dict),
- restore_variable_values_on_exit=True)
+ with stepper.NodeStepper(
+ run_context.session,
+        run_context.original_args.fetches,
+ run_context.original_args.feed_dict) as node_stepper:
+ self.invoke_node_stepper(
+ node_stepper, restore_variable_values_on_exit=True)
return run_args
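
For reference, here is a minimal usage sketch of the stepper API exercised by
the new tests above. It is not part of the patch: the toy graph (a, b, c, e)
and the session setup are assumptions made purely for illustration, and only
the NodeStepper calls (the context-manager form, cont(),
intermediate_tensor_names(), last_feed_types(), and the
use_dumped_intermediates flag) mirror what this commit adds or tests.

import tensorflow as tf
from tensorflow.python.debug.stepper import NodeStepper

a = tf.Variable(2.0, name="a")
b = tf.Variable(3.0, name="b")
c = tf.multiply(a, b, name="c")  # c = a * b
e = tf.multiply(c, b, name="e")  # e depends on both c and b/read

sess = tf.Session()
sess.run(tf.global_variables_initializer())

with NodeStepper(sess, e) as stepper:
  # cont() to c:0 dumps the intermediate tensors a/read:0 and b/read:0 and
  # caches a tensor handle to c:0.
  stepper.cont("c:0")
  print(stepper.intermediate_tensor_names())  # e.g. ["a/read:0", "b/read:0"]

  # By default, the next cont() feeds b/read:0 from the dump (and c:0 from its
  # handle) instead of recomputing them.
  print(stepper.cont("e:0"))
  print(stepper.last_feed_types())

  # The dumped intermediates can be ignored for an individual call.
  stepper.cont("e:0", use_dumped_intermediates=False)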