aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python/python_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/checkpoint/python/python_state.py')
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py40
1 files changed, 28 insertions, 12 deletions
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035b6d..302d5cfb79 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import functools
+import six
import numpy
@@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase):
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(NumpyState, self).__getattribute__(name)
existing.array = value
@@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase):
super(NumpyState, self).__setattr__(name, value)
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+ """Wraps a Python object for storage in an object-based checkpoint."""
+
+ @abc.abstractmethod
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+ @abc.abstractmethod
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "py_state": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
+
+
+class _NumpyWrapper(PythonStateWrapper):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase):
self.array = array
def _serialize(self):
- """Callback for `PythonStringStateSaveable` to serialize the array."""
+ """Callback to serialize the array."""
string_file = BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase):
return serialized
def _deserialize(self, string_value):
- """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ """Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
- def _gather_saveables_for_checkpoint(self):
- """Specify callbacks for saving and restoring `array`."""
- return {
- "array": functools.partial(
- base.PythonStringStateSaveable,
- state_callback=self._serialize,
- restore_callback=self._deserialize)
- }