aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-07 12:24:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 12:34:11 -0700
commitca92311cbdd3cecbb41c3f0012bcab90eef0c26f (patch)
treeddbc952f791332f0d700c5780a34472008ef7553 /tensorflow/python/training
parentbb096a0735445c7e05be4b68042f21660602001c (diff)
Builds a static tf.train.Saver from a checkpointable object graph
Moves around some SaveableObjects to support a freeze method for python state saveables, and makes sure that the object graph proto is included in the frozen Saver. This should be useful for embedding in SavedModels, where variables can be updated and the resulting checkpoints (saved from the SaverDef in the SavedModel) will still support Keras-style object-based restoration into Python programs (with better eager support and less fragile variable matching). This is also a step toward Estimators saving object-based checkpoints. PiperOrigin-RevId: 212017296
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/checkpointable/base.py66
-rw-r--r--tensorflow/python/training/checkpointable/util.py192
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py40
3 files changed, 235 insertions, 63 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
import json
import weakref
+import six
+
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+ """Embeds a tensor in a checkpoint with no restore ops."""
+
+ def __init__(self, tensor, name, dtype=None):
+ spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+ super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+ """An interface for saving/restoring volatile Python state."""
+
+ @abc.abstractmethod
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed.
+
+ Returns:
+ A dictionary mapping `Tensor`s to current Python state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def freeze(self):
+ """Create a new `SaveableObject` which freezes current state as a constant.
+
+ Used when executing eagerly to embed the current state as a constant, or
+ when creating a static tf.train.Saver with the frozen current Python state.
+
+ Returns:
+ A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+ no Python state associated with it).
+ """
+ pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing.
"""
+ self._state_callback = state_callback
self._restore_callback = restore_callback
- if context.executing_eagerly():
- self._save_string = (
- lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
- else:
+ with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
- self.feed_dict_additions = (
- lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed."""
+ return {self._save_string: self._state_callback()}
+
+ def freeze(self):
+ """Create a frozen `SaveableObject` which saves the current state."""
+ return NoRestoreSaveable(
+ tensor=self._state_callback,
+ dtype=dtypes.string,
+ name=self.name)
+
def python_restore(self, restored_strings):
"""Called to restore Python state."""
if self._restore_callback:
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
if self._checkpoint.saveable_object_cache is not None:
self._checkpoint.saveable_object_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
- if isinstance(saveable, PythonStringStateSaveable):
+ if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
- return json.dumps(self,
+ return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else:
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 13dddd37ac..56c4043d9d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
- feed_additions = {}
+ if saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+ # If we are caching SaveableObjects, we need to build up a feed_dict with
+ # functions computing volatile Python state to be saved with the checkpoint.
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@ def _serialize_checkpointables(
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
- saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
- if saveable_feed_dict_fn is not None:
- saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveables.extend(saveables)
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@ def capture_dependencies(template):
yield
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, tensor, name):
- spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
- super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
- def restore(self, restored_tensors, restored_shapes):
- return control_flow_ops.no_op()
-
-
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
+ def _gather_saveables(
+ self, object_graph_tensor=None, saveable_object_cache=None):
+ """Wraps _serialize_object_graph to include the object graph proto."""
+ assert ((object_graph_tensor is None and saveable_object_cache is None)
+ or (object_graph_tensor is not None
+ and saveable_object_cache is not None))
+ (named_saveable_objects, graph_proto,
+ feed_additions) = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=saveable_object_cache)
+ if object_graph_tensor is None:
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ else:
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects, graph_proto, feed_additions
+
+ def freeze(self):
+ """Creates a `tf.train.Saver` with the current object graph frozen."""
+ named_saveable_objects, _, _ = self._gather_saveables(
+ object_graph_tensor=None, saveable_object_cache=None)
+ return saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+
+ def _prepare_save(self,
+ object_graph_tensor=None,
+ saveable_object_cache=None):
+ """Create or retrieve save ops.
+
+ When graph building, `saveable_object_cache` will typically be non-`None`,
+ meaning that existing `SaveableObject`s are re-used across calls to
+ `_prepare_save` even if the object graph has grown. This avoids
+ unnecessarily re-creating save ops.
+
+ Args:
+ object_graph_tensor: A `Tensor` to which the current object graph will be
+ fed.
+ saveable_object_cache: A dictionary; if specified, used to cache
+ `SaveableObject`s.
+
+ Returns:
+ A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+ to feed when running save ops. The feed dict contains the current object
+ graph and any Python state to be saved in the checkpoint.
+ """
+ (named_saveable_objects, graph_proto,
+ feed_additions) = self._gather_saveables(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=saveable_object_cache)
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
+ if self._last_save_object_graph is not None:
+ self._last_save_saver = _copy_saver_with_new_var_list(
+ old_saver=self._last_save_saver,
+ new_var_list=named_saveable_objects)
+ else:
+ self._last_save_saver = saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+ self._last_save_object_graph = graph_proto
+ return self._last_save_saver, feed_additions
+
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto, feed_additions = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=self._saveable_object_cache)
- if not context.executing_eagerly():
- if session is None:
- session = ops.get_default_session()
+ feed_additions = {}
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions.update(
- {object_graph_tensor: graph_proto.SerializeToString()})
else:
+ object_graph_tensor = None
+
+ saver, new_feed_additions = self._prepare_save(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=self._saveable_object_cache)
+ if new_feed_additions:
+ feed_additions.update(new_feed_additions)
+ if not graph_building:
session = None
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables.append(
- _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- if (self._last_save_object_graph != graph_proto
- # When executing eagerly, we need to re-create SaveableObjects each time
- # save() is called so they pick up new Tensors passed to their
- # constructors. That means the Saver needs to be copied with a new
- # var_list.
- or context.executing_eagerly()):
- if self._last_save_object_graph is not None:
- self._last_save_saver = _copy_saver_with_new_var_list(
- old_saver=self._last_save_saver, new_var_list=named_variables)
- else:
- self._last_save_saver = saver_lib.Saver(
- var_list=named_variables, max_to_keep=None)
- self._last_save_object_graph = graph_proto
+ elif session is None:
+ session = ops.get_default_session()
+
with ops.device("/cpu:0"):
- save_path = self._last_save_saver.save(
+ save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
+def frozen_saver(root_checkpointable):
+ """Creates a static `tf.train.Saver` from a checkpointable object.
+
+ The returned `Saver` saves object-based checkpoints, but these checkpoints
+ will no longer reflect structural changes to the object graph, only changes to
+ the values of `Variable`s added as dependencies of the root object before
+ `freeze` was called.
+
+ `restore` works on the returned `Saver`, but requires that the object graph of
+ the checkpoint being loaded exactly matches the object graph when `freeze` was
+ called. This is in contrast the object-based restore performed by
+ `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
+ object graph and the current Python object graph.
+
+ Args:
+ root_checkpointable: A checkpointable object to save.
+
+ Returns:
+ A `tf.train.Saver` which saves object-based checkpoints for the object graph
+ frozen at the time `frozen_saver` was called.
+ """
+ return CheckpointableSaver(root_checkpointable).freeze()
+
+
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index bef4bf2a16..0d32d21426 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -560,6 +560,46 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
@test_util.run_in_graph_and_eager_modes
+ def testFreezing(self):
+ with self.cached_session(use_gpu=True) as session:
+ # Save an object-based checkpoint using a frozen saver
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ self.evaluate(v.assign(3))
+ # Create the save counter so assert_consumed doesn't complain about it not
+ # existing in the checkpoint on restore.
+ self.evaluate(checkpoint.save_counter.assign(12))
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ save_path = saver.save(session, prefix)
+ self.evaluate(v.assign(10))
+ # Use the frozen saver to restore the same object graph
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore using another frozen saver on an identical object graph
+ del v, checkpoint, saver
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore as an object-based checkpoint
+ del v, checkpoint, saver
+ checkpoint = checkpointable_utils.Checkpoint()
+ status = checkpoint.restore(save_path)
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ if context.executing_eagerly():
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+ self.assertEqual(0, self.evaluate(v))
+ checkpoint.v = v
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(3, self.evaluate(v))
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
def testCustomNumbering(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")