aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/base.py')
-rw-r--r--tensorflow/python/training/checkpointable/base.py66
1 files changed, 57 insertions, 9 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
import json
import weakref
+import six
+
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+ """Embeds a tensor in a checkpoint with no restore ops."""
+
+ def __init__(self, tensor, name, dtype=None):
+ spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+ super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+ """An interface for saving/restoring volatile Python state."""
+
+ @abc.abstractmethod
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed.
+
+ Returns:
+ A dictionary mapping `Tensor`s to current Python state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def freeze(self):
+ """Create a new `SaveableObject` which freezes current state as a constant.
+
+ Used when executing eagerly to embed the current state as a constant, or
+ when creating a static tf.train.Saver with the frozen current Python state.
+
+ Returns:
+ A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+ no Python state associated with it).
+ """
+ pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing.
"""
+ self._state_callback = state_callback
self._restore_callback = restore_callback
- if context.executing_eagerly():
- self._save_string = (
- lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
- else:
+ with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
- self.feed_dict_additions = (
- lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed."""
+ return {self._save_string: self._state_callback()}
+
+ def freeze(self):
+ """Create a frozen `SaveableObject` which saves the current state."""
+ return NoRestoreSaveable(
+ tensor=self._state_callback,
+ dtype=dtypes.string,
+ name=self.name)
+
def python_restore(self, restored_strings):
"""Called to restore Python state."""
if self._restore_callback:
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
if self._checkpoint.saveable_object_cache is not None:
self._checkpoint.saveable_object_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
- if isinstance(saveable, PythonStringStateSaveable):
+ if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
- return json.dumps(self,
+ return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else: