Diffstat (limited to 'tensorflow/python/training/checkpointable/util.py')
-rw-r--r--  tensorflow/python/training/checkpointable/util.py  192
1 file changed, 138 insertions(+), 54 deletions(-)
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 13dddd37ac..56c4043d9d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
- feed_additions = {}
+ if saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+    # If we are caching SaveableObjects, build up a feed_dict whose values
+    # are the volatile Python state to be written into the checkpoint.
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@ def _serialize_checkpointables(
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
- saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
- if saveable_feed_dict_fn is not None:
- saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveables.extend(saveables)
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
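The branch above decides, per saveable, whether volatile Python state is fed at save time (the SaveableObject-caching path) or frozen into the graph as a constant (eager execution or a static save). Below is a minimal, hedged sketch of that dispatch; `PythonState` and `FrozenValue` are illustrative stand-ins, not the TensorFlow `base.PythonStateSaveable` API, and duplicate-feed checking from the real code is omitted for brevity.

    # Hedged sketch of the feed-vs-freeze dispatch; these classes are
    # stand-ins, not TensorFlow types.
    class FrozenValue(object):
      """Pretend saveable wrapping an already-computed value."""

      def __init__(self, value):
        self.value = value


    class PythonState(object):
      """Pretend PythonStateSaveable: a placeholder plus a state function."""

      def __init__(self, placeholder, state_fn):
        self._placeholder = placeholder
        self._state_fn = state_fn

      def feed_dict_additions(self):
        # Recompute the Python state each time a checkpoint is written.
        return {self._placeholder: self._state_fn()}

      def freeze(self):
        # Embed the current state so no feed is needed at save time.
        return FrozenValue(self._state_fn())


    def gather(saveables, caching_saveables):
      named = []
      feed_additions = {} if caching_saveables else None
      for saveable in saveables:
        if isinstance(saveable, PythonState):
          if feed_additions is None:
            saveable = saveable.freeze()  # static or eager save
          else:
            feed_additions.update(saveable.feed_dict_additions())
        named.append(saveable)
      return named, feed_additions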
@@ -827,16 +840,6 @@ def capture_dependencies(template):
yield
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, tensor, name):
- spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
- super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
- def restore(self, restored_tensors, restored_shapes):
- return control_flow_ops.no_op()
-
-
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
+ def _gather_saveables(
+ self, object_graph_tensor=None, saveable_object_cache=None):
+ """Wraps _serialize_object_graph to include the object graph proto."""
+ assert ((object_graph_tensor is None and saveable_object_cache is None)
+ or (object_graph_tensor is not None
+ and saveable_object_cache is not None))
+ (named_saveable_objects, graph_proto,
+ feed_additions) = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=saveable_object_cache)
+ if object_graph_tensor is None:
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ else:
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects, graph_proto, feed_additions
+
+ def freeze(self):
+ """Creates a `tf.train.Saver` with the current object graph frozen."""
+ named_saveable_objects, _, _ = self._gather_saveables(
+ object_graph_tensor=None, saveable_object_cache=None)
+ return saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+
+ def _prepare_save(self,
+ object_graph_tensor=None,
+ saveable_object_cache=None):
+ """Create or retrieve save ops.
+
+ When graph building, `saveable_object_cache` will typically be non-`None`,
+ meaning that existing `SaveableObject`s are re-used across calls to
+ `_prepare_save` even if the object graph has grown. This avoids
+ unnecessarily re-creating save ops.
+
+ Args:
+ object_graph_tensor: A `Tensor` to which the current object graph will be
+ fed.
+ saveable_object_cache: A dictionary; if specified, used to cache
+ `SaveableObject`s.
+
+ Returns:
+ A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+ to feed when running save ops. The feed dict contains the current object
+ graph and any Python state to be saved in the checkpoint.
+ """
+ (named_saveable_objects, graph_proto,
+ feed_additions) = self._gather_saveables(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=saveable_object_cache)
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
+ if self._last_save_object_graph is not None:
+ self._last_save_saver = _copy_saver_with_new_var_list(
+ old_saver=self._last_save_saver,
+ new_var_list=named_saveable_objects)
+ else:
+ self._last_save_saver = saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+ self._last_save_object_graph = graph_proto
+ return self._last_save_saver, feed_additions
+
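`_prepare_save` above re-uses the previously built `tf.train.Saver` unless the serialized object graph changed or execution is eager. A hedged sketch of that memoization pattern follows; the injected `build_saver` and `copy_with_var_list` callables are hypothetical stand-ins for `saver_lib.Saver` and `_copy_saver_with_new_var_list`.

    class SaverCache(object):
      """Hedged sketch of the saver-reuse decision; not TensorFlow API."""

      def __init__(self, build_saver, copy_with_var_list):
        self._build_saver = build_saver
        self._copy_with_var_list = copy_with_var_list
        self._last_graph_proto = None
        self._last_saver = None

      def get(self, graph_proto, saveables, executing_eagerly):
        if self._last_graph_proto == graph_proto and not executing_eagerly:
          # Same object graph while graph building: re-use the cached save ops.
          return self._last_saver
        if self._last_saver is None:
          self._last_saver = self._build_saver(saveables)
        else:
          # Keep the saver but repoint it at the freshly built SaveableObjects.
          self._last_saver = self._copy_with_var_list(self._last_saver,
                                                      saveables)
        self._last_graph_proto = graph_proto
        return self._last_saver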
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto, feed_additions = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=self._saveable_object_cache)
- if not context.executing_eagerly():
- if session is None:
- session = ops.get_default_session()
+ feed_additions = {}
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions.update(
- {object_graph_tensor: graph_proto.SerializeToString()})
else:
+ object_graph_tensor = None
+
+ saver, new_feed_additions = self._prepare_save(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=self._saveable_object_cache)
+ if new_feed_additions:
+ feed_additions.update(new_feed_additions)
+ if not graph_building:
session = None
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables.append(
- _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- if (self._last_save_object_graph != graph_proto
- # When executing eagerly, we need to re-create SaveableObjects each time
- # save() is called so they pick up new Tensors passed to their
- # constructors. That means the Saver needs to be copied with a new
- # var_list.
- or context.executing_eagerly()):
- if self._last_save_object_graph is not None:
- self._last_save_saver = _copy_saver_with_new_var_list(
- old_saver=self._last_save_saver, new_var_list=named_variables)
- else:
- self._last_save_saver = saver_lib.Saver(
- var_list=named_variables, max_to_keep=None)
- self._last_save_object_graph = graph_proto
+ elif session is None:
+ session = ops.get_default_session()
+
with ops.device("/cpu:0"):
- save_path = self._last_save_saver.save(
+ save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
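With the restructuring above, `save` no longer needs a session when executing eagerly; SaveableObjects are rebuilt on each call so that new Tensors are picked up. A hedged usage sketch in TF 1.x-style eager mode (the variable and path are illustrative only):

    import tensorflow as tf
    from tensorflow.python.training.checkpointable import util

    tf.enable_eager_execution()
    root = tf.train.Checkpoint(step=tf.Variable(0))
    saver = util.CheckpointableSaver(root)
    # No session argument: the eager branch above skips the feed-based path.
    path = saver.save("/tmp/eager_ckpt")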
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
+def frozen_saver(root_checkpointable):
+ """Creates a static `tf.train.Saver` from a checkpointable object.
+
+ The returned `Saver` saves object-based checkpoints, but these checkpoints
+  will no longer reflect structural changes to the object graph, only changes
+  to the values of `Variable`s added as dependencies of the root object
+  before `freeze` was called.
+
+  `restore` works on the returned `Saver`, but requires that the object graph
+  of the checkpoint being loaded exactly matches the object graph when
+  `freeze` was called. This is in contrast to the object-based restore
+  performed by `tf.train.Checkpoint`, which attempts a fuzzy match between a
+  checkpoint's object graph and the current Python object graph.
+
+ Args:
+ root_checkpointable: A checkpointable object to save.
+
+ Returns:
+ A `tf.train.Saver` which saves object-based checkpoints for the object graph
+ frozen at the time `frozen_saver` was called.
+ """
+ return CheckpointableSaver(root_checkpointable).freeze()
+
+
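Because `frozen_saver` returns an ordinary name-based `tf.train.Saver`, it can be handed to code that knows nothing about object-based checkpointing (for example, a SavedModel export path). A hedged graph-mode usage sketch; the variable and path are illustrative only:

    import tensorflow as tf
    from tensorflow.python.training.checkpointable import util

    root = tf.train.Checkpoint(bias=tf.Variable([1.0], name="bias"))
    static_saver = util.frozen_saver(root)  # object graph is fixed here
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      path = static_saver.save(sess, "/tmp/frozen_ckpt")
      # Restoring requires an object graph identical to the frozen one.
      static_saver.restore(sess, path)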
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.