about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/python/training/checkpointable/base.py 66
-rw-r--r-- tensorflow/python/training/checkpointable/util.py 192
-rw-r--r-- tensorflow/python/training/checkpointable/util_test.py 40
3 files changed, 235 insertions, 63 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
import json
import weakref
+import six
+
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+ """Embeds a tensor in a checkpoint with no restore ops."""
+
+ def __init__(self, tensor, name, dtype=None):
+ spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+ super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+ """An interface for saving/restoring volatile Python state."""
+
+ @abc.abstractmethod
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed.
+
+ Returns:
+ A dictionary mapping `Tensor`s to current Python state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def freeze(self):
+ """Create a new `SaveableObject` which freezes current state as a constant.
+
+ Used when executing eagerly to embed the current state as a constant, or
+ when creating a static tf.train.Saver with the frozen current Python state.
+
+ Returns:
+ A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+ no Python state associated with it).
+ """
+ pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing.
"""
+ self._state_callback = state_callback
self._restore_callback = restore_callback
- if context.executing_eagerly():
- self._save_string = (
- lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
- else:
+ with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
- self.feed_dict_additions = (
- lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed."""
+ return {self._save_string: self._state_callback()}
+
+ def freeze(self):
+ """Create a frozen `SaveableObject` which saves the current state."""
+ return NoRestoreSaveable(
+ tensor=self._state_callback,
+ dtype=dtypes.string,
+ name=self.name)
+
def python_restore(self, restored_strings):
"""Called to restore Python state."""
if self._restore_callback:
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
if self._checkpoint.saveable_object_cache is not None:
self._checkpoint.saveable_object_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
- if isinstance(saveable, PythonStringStateSaveable):
+ if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
- return json.dumps(self,
+ return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else:
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 13dddd37ac..56c4043d9d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
- feed_additions = {}
+ if saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+ # If we are caching SaveableObjects, we need to build up a feed_dict with
+ # functions computing volatile Python state to be saved with the checkpoint.
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@ def _serialize_checkpointables(
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
- saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
- if saveable_feed_dict_fn is not None:
- saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveables.extend(saveables)
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@ def capture_dependencies(template):
yield
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, tensor, name):
- spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
- super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
- def restore(self, restored_tensors, restored_shapes):
- return control_flow_ops.no_op()
-
-
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
+ def _gather_saveables(
+ self, object_graph_tensor=None, saveable_object_cache=None):
+ """Wraps _serialize_object_graph to include the object graph proto."""
+ assert ((object_graph_tensor is None and saveable_object_cache is None)
+ or (object_graph_tensor is not None
+ and saveable_object_cache is not None))
+ (named_saveable_objects, graph_proto,
+ feed_additions) = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=saveable_object_cache)
+ if object_graph_tensor is None:
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ else:
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects, graph_proto, feed_additions
+
+ def freeze(self):
+ """Creates a `tf.train.Saver` with the current object graph frozen."""
+ named_saveable_objects, _, _ = self._gather_saveables(
+ object_graph_tensor=None, saveable_object_cache=None)
+ return saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+
+ def _prepare_save(self,
+ object_graph_tensor=None,
+ saveable_object_cache=None):
+ """Create or retrieve save ops.
+
+ When graph building, `saveable_object_cache` will typically be non-`None`,
+ meaning that existing `SaveableObject`s are re-used across calls to
+ `_prepare_save` even if the object graph has grown. This avoids
+ unnecessarily re-creating save ops.
+
+ Args:
+ object_graph_tensor: A `Tensor` to which the current object graph will be
+ fed.
+ saveable_object_cache: A dictionary; if specified, used to cache
+ `SaveableObject`s.
+
+ Returns:
+ A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+ to feed when running save ops. The feed dict contains the current object
+ graph and any Python state to be saved in the checkpoint.
+ """
+ (named_saveable_objects, graph_proto,
+ feed_additions) = self._gather_saveables(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=saveable_object_cache)
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
+ if self._last_save_object_graph is not None:
+ self._last_save_saver = _copy_saver_with_new_var_list(
+ old_saver=self._last_save_saver,
+ new_var_list=named_saveable_objects)
+ else:
+ self._last_save_saver = saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+ self._last_save_object_graph = graph_proto
+ return self._last_save_saver, feed_additions
+
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto, feed_additions = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=self._saveable_object_cache)
- if not context.executing_eagerly():
- if session is None:
- session = ops.get_default_session()
+ feed_additions = {}
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions.update(
- {object_graph_tensor: graph_proto.SerializeToString()})
else:
+ object_graph_tensor = None
+
+ saver, new_feed_additions = self._prepare_save(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=self._saveable_object_cache)
+ if new_feed_additions:
+ feed_additions.update(new_feed_additions)
+ if not graph_building:
session = None
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables.append(
- _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- if (self._last_save_object_graph != graph_proto
- # When executing eagerly, we need to re-create SaveableObjects each time
- # save() is called so they pick up new Tensors passed to their
- # constructors. That means the Saver needs to be copied with a new
- # var_list.
- or context.executing_eagerly()):
- if self._last_save_object_graph is not None:
- self._last_save_saver = _copy_saver_with_new_var_list(
- old_saver=self._last_save_saver, new_var_list=named_variables)
- else:
- self._last_save_saver = saver_lib.Saver(
- var_list=named_variables, max_to_keep=None)
- self._last_save_object_graph = graph_proto
+ elif session is None:
+ session = ops.get_default_session()
+
with ops.device("/cpu:0"):
- save_path = self._last_save_saver.save(
+ save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
+def frozen_saver(root_checkpointable):
+ """Creates a static `tf.train.Saver` from a checkpointable object.
+
+ The returned `Saver` saves object-based checkpoints, but these checkpoints
+ will no longer reflect structural changes to the object graph, only changes to
+ the values of `Variable`s added as dependencies of the root object before
+ `freeze` was called.
+
+ `restore` works on the returned `Saver`, but requires that the object graph of
+ the checkpoint being loaded exactly matches the object graph when `freeze` was
+ called. This is in contrast to the object-based restore performed by
+ `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
+ object graph and the current Python object graph.
+
+ Args:
+ root_checkpointable: A checkpointable object to save.
+
+ Returns:
+ A `tf.train.Saver` which saves object-based checkpoints for the object graph
+ frozen at the time `frozen_saver` was called.
+ """
+ return CheckpointableSaver(root_checkpointable).freeze()
+
+
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index bef4bf2a16..0d32d21426 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -560,6 +560,46 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
@test_util.run_in_graph_and_eager_modes
+ def testFreezing(self):
+ with self.cached_session(use_gpu=True) as session:
+ # Save an object-based checkpoint using a frozen saver
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ self.evaluate(v.assign(3))
+ # Create the save counter so assert_consumed doesn't complain about it not
+ # existing in the checkpoint on restore.
+ self.evaluate(checkpoint.save_counter.assign(12))
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ save_path = saver.save(session, prefix)
+ self.evaluate(v.assign(10))
+ # Use the frozen saver to restore the same object graph
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore using another frozen saver on an identical object graph
+ del v, checkpoint, saver
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore as an object-based checkpoint
+ del v, checkpoint, saver
+ checkpoint = checkpointable_utils.Checkpoint()
+ status = checkpoint.restore(save_path)
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ if context.executing_eagerly():
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+ self.assertEqual(0, self.evaluate(v))
+ checkpoint.v = v
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(3, self.evaluate(v))
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
def testCustomNumbering(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")