aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-20 12:40:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-20 12:43:51 -0700
commit517d1912f4ec71180944320350a3694332a1dedc (patch)
tree23fd05eb7e72c8ce13766a516280622089cd0c70 /tensorflow/contrib/checkpoint
parentb23e91d247368f2046dae035b5c7bdda56512077 (diff)
Add a utility to visualize object-based checkpoints
Useful for generating a warm fuzzy feeling that everything you think should be saved was saved, and for explaining what object-based checkpointing is. (Also useful on the former front will be a planned "assert that all of this Graph's trainable variables are accessible from object X" function.) Somewhat hacky since it generates strings rather than using the pydot bindings (and so works without a pydot dependency). PiperOrigin-RevId: 193708003
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py3
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD32
-rw-r--r--tensorflow/contrib/checkpoint/python/visualize.py111
-rw-r--r--tensorflow/contrib/checkpoint/python/visualize_test.py97
4 files changed, 243 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 70d7d2d8d7..1192cc44a1 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -16,6 +16,7 @@
For creating and managing dependencies:
+@@dot_graph_from_checkpoint
@@split_dependency
"""
@@ -24,6 +25,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
+from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
+
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index d57b01aab2..a5681ffa61 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -5,6 +5,15 @@ package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
+ name = "checkpoint",  # Aggregate target: declares no srcs, only the two deps listed below.
+ srcs_version = "PY2AND3",
+ deps = [
+ ":split_dependency",
+ ":visualize",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
@@ -27,3 +36,26 @@ py_test(
"//tensorflow/python/eager:test",
],
)
+
+py_library(
+ name = "visualize",  # Library target for visualize.py.
+ srcs = ["visualize.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [  # NOTE(review): visualize.py also imports checkpointable_object_graph_pb2, errors_impl and checkpointable -- confirm they are reachable via the dep below or add explicit deps.
+ "//tensorflow/python:pywrap_tensorflow",
+ ],
+)
+
+py_test(
+ name = "visualize_test",  # Test target for visualize_test.py.
+ srcs = ["visualize_test.py"],
+ deps = [  # NOTE(review): the test imports eager context, constant_op, keras and checkpointable_utils, and does not import array_ops -- confirm this list matches.
+ ":visualize",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ ],
+)
diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py
new file mode 100644
index 0000000000..86fbdb41d2
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/visualize.py
@@ -0,0 +1,111 @@
+"""Utilities for visualizing dependency graphs."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.training import checkpointable
+
+
+def dot_graph_from_checkpoint(save_path):
+ r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`).
+
+ Useful for inspecting checkpoints and debugging loading issues.
+
+ Example usage from Python (requires pydot):
+ ```python
+ import tensorflow as tf
+ import pydot
+
+ dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
+ parsed, = pydot.graph_from_dot_data(dot_string)
+ parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
+ ```
+
+ Example command line usage:
+ ```sh
+ python -c "import tensorflow as tf;\
+ print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
+ | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
+ ```
+
+ Args:
+ save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
+ or `tf.train.latest_checkpoint`.
+ Returns:
+ A graph in DOT format as a string.
+ """
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ try:
+ object_graph_string = reader.get_tensor(
+ checkpointable.OBJECT_GRAPH_PROTO_KEY)
+ except errors_impl.NotFoundError:
+ raise ValueError(
+ ('The specified checkpoint "%s" does not appear to be object-based (it '
+ 'is missing the key "%s"). Likely it was created with a name-based '
+ 'saver and does not contain an object dependency graph.') % (
+ save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
+ shape_map = reader.get_variable_to_shape_map()
+ dtype_map = reader.get_variable_to_dtype_map()
+ object_graph = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+ object_graph.ParseFromString(object_graph_string)
+ graph = 'digraph {\n'
+ def _escape(name):
+ return name.replace('"', '\\"')
+ slot_ids = set()
+ for node in object_graph.nodes:
+ for slot_reference in node.slot_variables:
+ slot_ids.add(slot_reference.slot_variable_node_id)
+ for node_id, node in enumerate(object_graph.nodes):
+ if (len(node.attributes) == 1
+ and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
+ if node_id in slot_ids:
+ color = 'orange'
+ tooltip_prefix = 'Slot variable'
+ else:
+ color = 'blue'
+ tooltip_prefix = 'Variable'
+ attribute = node.attributes[0]
+ graph += ('N_%d [shape=point label="" color=%s width=.25'
+ ' tooltip="%s %s shape=%s %s"]\n') % (
+ node_id,
+ color,
+ tooltip_prefix,
+ _escape(attribute.full_name),
+ shape_map[attribute.checkpoint_key],
+ dtype_map[attribute.checkpoint_key].name)
+ elif node.slot_variables:
+ graph += ('N_%d [shape=point label="" width=.25 color=red,'
+ 'tooltip="Optimizer"]\n') % node_id
+ else:
+ graph += 'N_%d [shape=point label="" width=.25]\n' % node_id
+ for reference in node.children:
+ graph += 'N_%d -> N_%d [label="%s"]\n' % (
+ node_id, reference.node_id, _escape(reference.local_name))
+ for slot_reference in node.slot_variables:
+ graph += 'N_%d -> N_%d [label="%s" style=dotted]\n' % (
+ node_id,
+ slot_reference.slot_variable_node_id,
+ _escape(slot_reference.slot_name))
+ graph += 'N_%d -> N_%d [style=dotted]\n' % (
+ slot_reference.original_variable_node_id,
+ slot_reference.slot_variable_node_id)
+ graph += '}\n'
+ return graph
diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py
new file mode 100644
index 0000000000..1d9ab78923
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/visualize_test.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+from tensorflow.contrib.checkpoint.python import visualize
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import checkpointable_utils
+
+try:
+ import pydot # pylint: disable=g-import-not-at-top
+except ImportError:
+ pydot = None
+
+
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class DotGraphTests(test.TestCase):
+
+ def testMakeDotGraph(self):
+ with context.eager_mode():
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001)
+ optimizer_step = resource_variable_ops.ResourceVariable(12)
+ save_checkpoint = checkpointable_utils.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ optimizer.minimize(functools.partial(model, input_value))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ save_path = save_checkpoint.save(checkpoint_prefix)
+ prefix = save_checkpoint.save(save_path)
+
+ dot_graph_string = visualize.dot_graph_from_checkpoint(prefix)
+
+ # The remainder of this test is more-or-less optional since it's so
+ # dependent on pydot/platform/Python versions.
+ if pydot is None:
+ self.skipTest('pydot is required for the remainder of this test.')
+ try:
+ parsed, = pydot.graph_from_dot_data(dot_graph_string)
+ except NameError as e:
+ if "name 'dot_parser' is not defined" in str(e):
+ self.skipTest("pydot isn't working")
+ else:
+ raise
+ # Check that the graph isn't completely trivial
+ self.assertEqual(
+ '"model"',
+ parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label'])
+ image_path = os.path.join(self.get_temp_dir(), 'saved.svg')
+ try:
+ parsed.write_svg(image_path)
+ except Exception as e: # pylint: disable=broad-except
+ # For some reason PyDot's "dot not available" error is an Exception, not
+ # something more specific.
+ if '"dot" not found in path' in str(e):
+ self.skipTest("pydot won't save SVGs (dot not available)")
+ else:
+ raise
+
+if __name__ == '__main__':  # Only when this file is executed as a script.
+ test.main()  # Entry point of the `test` module imported from tensorflow.python.eager.