aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-03 13:22:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:44:41 -0700
commit5a64e609d0eb94244067f5d7514605863c9f37c3 (patch)
treedbfe0ff8e2aa3ff788beabef757e90f4c2dbe4df /tensorflow/contrib/checkpoint
parent41dcb67efd272e9ce0e5071433f42a9d540ec6dc (diff)
Checkpointable: Utilities to read object metadata
Useful for inspecting checkpoints programatically (e.g. in unit tests). PiperOrigin-RevId: 195300780
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py4
-rw-r--r--tensorflow/contrib/checkpoint/python/visualize.py16
2 files changed, 6 insertions, 14 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 1192cc44a1..d2c30f1215 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -16,7 +16,9 @@
For creating and managing dependencies:
+@@CheckpointableObjectGraph
@@dot_graph_from_checkpoint
+@@object_metadata
@@split_dependency
"""
@@ -26,6 +28,8 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
+from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpointable_utils import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py
index 86fbdb41d2..9a3b23bb2c 100644
--- a/tensorflow/contrib/checkpoint/python/visualize.py
+++ b/tensorflow/contrib/checkpoint/python/visualize.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import errors_impl
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
def dot_graph_from_checkpoint(save_path):
@@ -52,20 +51,9 @@ def dot_graph_from_checkpoint(save_path):
A graph in DOT format as a string.
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
- try:
- object_graph_string = reader.get_tensor(
- checkpointable.OBJECT_GRAPH_PROTO_KEY)
- except errors_impl.NotFoundError:
- raise ValueError(
- ('The specified checkpoint "%s" does not appear to be object-based (it '
- 'is missing the key "%s"). Likely it was created with a name-based '
- 'saver and does not contain an object dependency graph.') % (
- save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
+ object_graph = checkpointable_utils.object_metadata(save_path)
shape_map = reader.get_variable_to_shape_map()
dtype_map = reader.get_variable_to_dtype_map()
- object_graph = (
- checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- object_graph.ParseFromString(object_graph_string)
graph = 'digraph {\n'
def _escape(name):
return name.replace('"', '\\"')