diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/util.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/util.py | 192 |
1 files changed, 138 insertions, 54 deletions
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 13dddd37ac..56c4043d9d 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope @@ -557,7 +556,14 @@ def _serialize_checkpointables( object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) named_saveables = [] - feed_additions = {} + if saveables_cache is None: + # No SaveableObject caching. Either we're executing eagerly, or building a + # static save which is specialized to the current Python state. + feed_additions = None + else: + # If we are caching SaveableObjects, we need to build up a feed_dict with + # functions computing volatile Python state to be saved with the checkpoint. + feed_additions = {} for checkpoint_id, checkpointable in enumerate(checkpointable_objects): assert node_ids[checkpointable] == checkpoint_id object_proto = object_graph_proto.nodes.add() @@ -616,18 +622,25 @@ def _serialize_checkpointables( for saveable in saveables: if hasattr(saveable, "full_name"): attribute.full_name = saveable.full_name - saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None) - if saveable_feed_dict_fn is not None: - saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable - for new_feed_key in saveable_feed_dict.keys(): - if new_feed_key in feed_additions: - raise AssertionError( - ("The object %s tried to feed a value for the Tensor %s " - "when saving, but another object is already feeding a " - "value.") - % (checkpointable, new_feed_key)) - feed_additions.update(saveable_feed_dict) - named_saveables.extend(saveables) + if isinstance(saveable, base.PythonStateSaveable): + if feed_additions is None: + assert saveables_cache is None + # If we're not caching saveables, then we're either executing + # eagerly or building a static save/restore (e.g. for a + # SavedModel). In either case, we should embed the current Python + # state in the graph rather than relying on a feed dict. + saveable = saveable.freeze() + else: + saveable_feed_dict = saveable.feed_dict_additions() + for new_feed_key in saveable_feed_dict.keys(): + if new_feed_key in feed_additions: + raise AssertionError( + ("The object %s tried to feed a value for the Tensor %s " + "when saving, but another object is already feeding a " + "value.") + % (checkpointable, new_feed_key)) + feed_additions.update(saveable_feed_dict) + named_saveables.append(saveable) for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access child_proto = object_proto.children.add() @@ -827,16 +840,6 @@ def capture_dependencies(template): yield -class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): - - def __init__(self, tensor, name): - spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name) - super(_NoRestoreSaveable, self).__init__(tensor, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - return control_flow_ops.no_op() - - class _LoadStatus(object): """Abstract base for load status callbacks.""" @@ -1241,6 +1244,78 @@ class CheckpointableSaver(object): else: return self._root_checkpointable_ref + def _gather_saveables( + self, object_graph_tensor=None, saveable_object_cache=None): + """Wraps _serialize_object_graph to include the object graph proto.""" + assert ((object_graph_tensor is None and saveable_object_cache is None) + or (object_graph_tensor is not None + and saveable_object_cache is not None)) + (named_saveable_objects, graph_proto, + feed_additions) = _serialize_object_graph( + self._root_checkpointable, + saveables_cache=saveable_object_cache) + if object_graph_tensor is None: + with ops.device("/cpu:0"): + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) + else: + feed_additions.update( + {object_graph_tensor: graph_proto.SerializeToString()}) + assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects + named_saveable_objects.append( + base.NoRestoreSaveable( + tensor=object_graph_tensor, + name=base.OBJECT_GRAPH_PROTO_KEY)) + return named_saveable_objects, graph_proto, feed_additions + + def freeze(self): + """Creates a `tf.train.Saver` with the current object graph frozen.""" + named_saveable_objects, _, _ = self._gather_saveables( + object_graph_tensor=None, saveable_object_cache=None) + return saver_lib.Saver( + var_list=named_saveable_objects, max_to_keep=None) + + def _prepare_save(self, + object_graph_tensor=None, + saveable_object_cache=None): + """Create or retrieve save ops. + + When graph building, `saveable_object_cache` will typically be non-`None`, + meaning that existing `SaveableObject`s are re-used across calls to + `_prepare_save` even if the object graph has grown. This avoids + unnecessarily re-creating save ops. + + Args: + object_graph_tensor: A `Tensor` to which the current object graph will be + fed. + saveable_object_cache: A dictionary; if specified, used to cache + `SaveableObject`s. + + Returns: + A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s + to feed when running save ops. The feed dict contains the current object + graph and any Python state to be saved in the checkpoint. + """ + (named_saveable_objects, graph_proto, + feed_additions) = self._gather_saveables( + object_graph_tensor=object_graph_tensor, + saveable_object_cache=saveable_object_cache) + if (self._last_save_object_graph != graph_proto + # When executing eagerly, we need to re-create SaveableObjects each time + # save() is called so they pick up new Tensors passed to their + # constructors. That means the Saver needs to be copied with a new + # var_list. + or context.executing_eagerly()): + if self._last_save_object_graph is not None: + self._last_save_saver = _copy_saver_with_new_var_list( + old_saver=self._last_save_saver, + new_var_list=named_saveable_objects) + else: + self._last_save_saver = saver_lib.Saver( + var_list=named_saveable_objects, max_to_keep=None) + self._last_save_object_graph = graph_proto + return self._last_save_saver, feed_additions + def save(self, file_prefix, checkpoint_number=None, session=None): """Save a training checkpoint. @@ -1263,44 +1338,29 @@ class CheckpointableSaver(object): Returns: The full path to the checkpoint. """ - named_variables, graph_proto, feed_additions = _serialize_object_graph( - self._root_checkpointable, - saveables_cache=self._saveable_object_cache) - if not context.executing_eagerly(): - if session is None: - session = ops.get_default_session() + feed_additions = {} + graph_building = not context.executing_eagerly() + if graph_building: if self._object_graph_feed_tensor is None: with ops.device("/cpu:0"): self._object_graph_feed_tensor = constant_op.constant( "", dtype=dtypes.string) object_graph_tensor = self._object_graph_feed_tensor - feed_additions.update( - {object_graph_tensor: graph_proto.SerializeToString()}) else: + object_graph_tensor = None + + saver, new_feed_additions = self._prepare_save( + object_graph_tensor=object_graph_tensor, + saveable_object_cache=self._saveable_object_cache) + if new_feed_additions: + feed_additions.update(new_feed_additions) + if not graph_building: session = None - with ops.device("/cpu:0"): - object_graph_tensor = constant_op.constant( - graph_proto.SerializeToString(), dtype=dtypes.string) - assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables - named_variables.append( - _NoRestoreSaveable( - tensor=object_graph_tensor, - name=base.OBJECT_GRAPH_PROTO_KEY)) - if (self._last_save_object_graph != graph_proto - # When executing eagerly, we need to re-create SaveableObjects each time - # save() is called so they pick up new Tensors passed to their - # constructors. That means the Saver needs to be copied with a new - # var_list. - or context.executing_eagerly()): - if self._last_save_object_graph is not None: - self._last_save_saver = _copy_saver_with_new_var_list( - old_saver=self._last_save_saver, new_var_list=named_variables) - else: - self._last_save_saver = saver_lib.Saver( - var_list=named_variables, max_to_keep=None) - self._last_save_object_graph = graph_proto + elif session is None: + session = ops.get_default_session() + with ops.device("/cpu:0"): - save_path = self._last_save_saver.save( + save_path = saver.save( sess=_SessionWithFeedDictAdditions( session=session, feed_additions=feed_additions), save_path=file_prefix, @@ -1422,6 +1482,30 @@ class CheckpointableSaver(object): return load_status +def frozen_saver(root_checkpointable): + """Creates a static `tf.train.Saver` from a checkpointable object. + + The returned `Saver` saves object-based checkpoints, but these checkpoints + will no longer reflect structural changes to the object graph, only changes to + the values of `Variable`s added as dependencies of the root object before + `freeze` was called. + + `restore` works on the returned `Saver`, but requires that the object graph of + the checkpoint being loaded exactly matches the object graph when `freeze` was + called. This is in contrast the object-based restore performed by + `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's + object graph and the current Python object graph. + + Args: + root_checkpointable: A checkpointable object to save. + + Returns: + A `tf.train.Saver` which saves object-based checkpoints for the object graph + frozen at the time `frozen_saver` was called. + """ + return CheckpointableSaver(root_checkpointable).freeze() + + @tf_export("train.Checkpoint") class Checkpoint(tracking.Checkpointable): """Groups checkpointable objects, saving and restoring them. |