diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-03 13:22:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-03 13:44:41 -0700 |
commit | 5a64e609d0eb94244067f5d7514605863c9f37c3 (patch) | |
tree | dbfe0ff8e2aa3ff788beabef757e90f4c2dbe4df /tensorflow/contrib/checkpoint | |
parent | 41dcb67efd272e9ce0e5071433f42a9d540ec6dc (diff) |
Checkpointable: Utilities to read object metadata
Useful for inspecting checkpoints programatically (e.g. in unit tests).
PiperOrigin-RevId: 195300780
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/visualize.py | 16 |
2 files changed, 6 insertions, 14 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 1192cc44a1..d2c30f1215 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,7 +16,9 @@ For creating and managing dependencies: +@@CheckpointableObjectGraph @@dot_graph_from_checkpoint +@@object_metadata @@split_dependency """ @@ -26,6 +28,8 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint +from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable_utils import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index 86fbdb41d2..9a3b23bb2c 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -17,10 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import errors_impl from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils def dot_graph_from_checkpoint(save_path): @@ -52,20 +51,9 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - try: - object_graph_string = reader.get_tensor( - checkpointable.OBJECT_GRAPH_PROTO_KEY) - except errors_impl.NotFoundError: - raise ValueError( - ('The specified checkpoint "%s" does not appear to be object-based (it ' - 'is missing the key "%s"). Likely it was created with a name-based ' - 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY)) + object_graph = checkpointable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() - object_graph = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph.ParseFromString(object_graph_string) graph = 'digraph {\n' def _escape(name): return name.replace('"', '\\"') |