diff options
Diffstat (limited to 'tensorflow/contrib/checkpoint/python/python_state.py')
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state.py | 40 |
1 files changed, 28 insertions, 12 deletions
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 9b11035b6d..302d5cfb79 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import functools +import six import numpy @@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase): # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making # ndarrays checkpointable natively and using standard checkpointable list # tracking. - if isinstance(value, numpy.ndarray): + if isinstance(value, (numpy.ndarray, numpy.generic)): try: existing = super(NumpyState, self).__getattribute__(name) existing.array = value @@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase): super(NumpyState, self).__setattr__(name, value) -class _NumpyWrapper(base.CheckpointableBase): +@six.add_metaclass(abc.ABCMeta) +class PythonStateWrapper(base.CheckpointableBase): + """Wraps a Python object for storage in an object-based checkpoint.""" + + @abc.abstractmethod + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the object.""" + + @abc.abstractmethod + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the object.""" + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "py_state": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } + + +class _NumpyWrapper(PythonStateWrapper): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase): self.array = array def _serialize(self): - """Callback for `PythonStringStateSaveable` to serialize the array.""" + """Callback to serialize the array.""" string_file = BytesIO() try: numpy.save(string_file, self.array, allow_pickle=False) @@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase): return serialized def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the array.""" + """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "array": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } |