aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint/python/python_state.py
diff options
context:
space:
mode:
authorGravatar avijit-nervana <avijit.chakraborty@intel.com>2018-09-14 09:21:08 -0700
committerGravatar avijit-nervana <avijit.chakraborty@intel.com>2018-09-14 09:21:08 -0700
commit41aaed7751690b0b3137dad2620656a698b3ceae (patch)
tree00fc1a7f6be0c3968f3e674a65ca4907110ddf2d /tensorflow/contrib/checkpoint/python/python_state.py
parentc26c5e1217944448f1f4c2b97626fc4d7d6406d3 (diff)
parent95338704198205c1bdec1e344e103f1daf05df68 (diff)
Merge branch 'master' into avijit/add-cpu-backend
Diffstat (limited to 'tensorflow/contrib/checkpoint/python/python_state.py')
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py40
1 files changed, 28 insertions, 12 deletions
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035b6d..302d5cfb79 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import functools
+import six
import numpy
@@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase):
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(NumpyState, self).__getattribute__(name)
existing.array = value
@@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase):
super(NumpyState, self).__setattr__(name, value)
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+ """Wraps a Python object for storage in an object-based checkpoint."""
+
+ @abc.abstractmethod
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+ @abc.abstractmethod
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "py_state": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
+
+
+class _NumpyWrapper(PythonStateWrapper):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase):
self.array = array
def _serialize(self):
- """Callback for `PythonStringStateSaveable` to serialize the array."""
+ """Callback to serialize the array."""
string_file = BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase):
return serialized
def _deserialize(self, string_value):
- """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ """Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
- def _gather_saveables_for_checkpoint(self):
- """Specify callbacks for saving and restoring `array`."""
- return {
- "array": functools.partial(
- base.PythonStringStateSaveable,
- state_callback=self._serialize,
- restore_callback=self._deserialize)
- }