diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/base.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/base.py | 66 |
1 file changed, 57 insertions, 9 deletions
# Review notes (commentary only; this text sits ahead of the `diff --git`
# file header and belongs to no hunk).
#
# What the patch below does, hunk by hunk:
#   1. Imports `abc` and `six`.
#   2. Adds `NoRestoreSaveable`, a `SaveableObject` whose `restore()` returns
#      `control_flow_ops.no_op()`, and the abstract `PythonStateSaveable`
#      interface (`feed_dict_additions()` and `freeze()`, both declared with
#      `@abc.abstractmethod`). `PythonStringStateSaveable` is re-parented
#      onto `PythonStateSaveable`.
#   3. `PythonStringStateSaveable.__init__` stops branching on
#      `context.executing_eagerly()`: it stores `state_callback`, always
#      builds an empty string constant under `ops.device("/cpu:0")`, and
#      `feed_dict_additions` / `freeze` become regular methods instead of a
#      lambda attribute assigned in only one branch.
#   4. `_CheckpointPosition` widens its `isinstance` check from
#      `PythonStringStateSaveable` to the new `PythonStateSaveable` base.
#   5. `_state_callback` serializes `dereferenced_self` (the result of
#      `weak_self()`) instead of the closed-over `self`.
#      NOTE(review): presumably so the callback no longer keeps a strong
#      reference to the object -- confirm against the enclosing method.
#
# NOTE(review): `freeze()` passes the callable `self._state_callback` itself
# as `tensor=`, not its return value. That is only correct if
# `saveable_object.SaveSpec` accepts a callable tensor -- TODO confirm.
#
# NOTE(review): the rendering this was recovered from had collapsed all
# newlines and indentation. Line breaks were re-derived from the hunk
# headers' line counts; leading whitespace inside hunks is reconstructed
# from Python syntax and the file's 2-space style -- verify against the
# repository before applying this as a patch.
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import functools
 import json
 import weakref
 
+import six
+
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
     return self._checkpoint_position
 
 
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+  """Embeds a tensor in a checkpoint with no restore ops."""
+
+  def __init__(self, tensor, name, dtype=None):
+    spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+  """An interface for saving/restoring volatile Python state."""
+
+  @abc.abstractmethod
+  def feed_dict_additions(self):
+    """When running a graph, indicates fresh state to feed.
+
+    Returns:
+      A dictionary mapping `Tensor`s to current Python state.
+    """
+    pass
+
+  @abc.abstractmethod
+  def freeze(self):
+    """Create a new `SaveableObject` which freezes current state as a constant.
+
+    Used when executing eagerly to embed the current state as a constant, or
+    when creating a static tf.train.Saver with the frozen current Python state.
+
+    Returns:
+      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+      no Python state associated with it).
+    """
+    pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
   """Saves Python state in a checkpoint."""
 
   def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
       restore_callback: A function taking a Python string, used to restore
         state. Optional; defaults to doing nothing.
     """
+    self._state_callback = state_callback
     self._restore_callback = restore_callback
-    if context.executing_eagerly():
-      self._save_string = (
-          lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
-    else:
+    with ops.device("/cpu:0"):
       self._save_string = constant_op.constant("", dtype=dtypes.string)
-      self.feed_dict_additions = (
-          lambda: {self._save_string: state_callback()})
     spec = saveable_object.SaveSpec(
         self._save_string, "", name, dtype=dtypes.string)
     super(PythonStringStateSaveable, self).__init__(
         self._save_string, [spec], name)
 
+  def feed_dict_additions(self):
+    """When running a graph, indicates fresh state to feed."""
+    return {self._save_string: self._state_callback()}
+
+  def freeze(self):
+    """Create a frozen `SaveableObject` which saves the current state."""
+    return NoRestoreSaveable(
+        tensor=self._state_callback,
+        dtype=dtypes.string,
+        name=self.name)
+
   def python_restore(self, restored_strings):
     """Called to restore Python state."""
     if self._restore_callback:
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
         if self._checkpoint.saveable_object_cache is not None:
           self._checkpoint.saveable_object_cache.setdefault(
               self.checkpointable, {})[serialized_tensor.name] = [saveable]
-      if isinstance(saveable, PythonStringStateSaveable):
+      if isinstance(saveable, PythonStateSaveable):
         python_saveables.append(saveable)
       else:
         named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
     def _state_callback():
       dereferenced_self = weak_self()
       if dereferenced_self:
-        return json.dumps(self,
+        return json.dumps(dereferenced_self,
                           default=serialization.get_json_type,
                           sort_keys=True).encode("utf8")
       else:
|