aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-03 13:22:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:44:41 -0700
commit5a64e609d0eb94244067f5d7514605863c9f37c3 (patch)
treedbfe0ff8e2aa3ff788beabef757e90f4c2dbe4df /tensorflow/contrib/checkpoint/python
parent41dcb67efd272e9ce0e5071433f42a9d540ec6dc (diff)
Checkpointable: Utilities to read object metadata
Useful for inspecting checkpoints programatically (e.g. in unit tests). PiperOrigin-RevId: 195300780
Diffstat (limited to 'tensorflow/contrib/checkpoint/python')
-rw-r--r--tensorflow/contrib/checkpoint/python/visualize.py16
1 files changed, 2 insertions, 14 deletions
diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py
index 86fbdb41d2..9a3b23bb2c 100644
--- a/tensorflow/contrib/checkpoint/python/visualize.py
+++ b/tensorflow/contrib/checkpoint/python/visualize.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import errors_impl
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
def dot_graph_from_checkpoint(save_path):
@@ -52,20 +51,9 @@ def dot_graph_from_checkpoint(save_path):
A graph in DOT format as a string.
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
- try:
- object_graph_string = reader.get_tensor(
- checkpointable.OBJECT_GRAPH_PROTO_KEY)
- except errors_impl.NotFoundError:
- raise ValueError(
- ('The specified checkpoint "%s" does not appear to be object-based (it '
- 'is missing the key "%s"). Likely it was created with a name-based '
- 'saver and does not contain an object dependency graph.') % (
- save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
+ object_graph = checkpointable_utils.object_metadata(save_path)
shape_map = reader.get_variable_to_shape_map()
dtype_map = reader.get_variable_to_dtype_map()
- object_graph = (
- checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- object_graph.ParseFromString(object_graph_string)
graph = 'digraph {\n'
def _escape(name):
return name.replace('"', '\\"')