aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-13 14:31:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 14:35:37 -0700
commit4999d856d2953aee56fa9759f995038edf3ff566 (patch)
tree00ad48d221451352c361534f18da5ad50f7101b7 /tensorflow/contrib/checkpoint
parentd3458112ad5a1612ec6c77f7de4a0e0ec801e882 (diff)
Expose tf.contrib.checkpoint.PythonStateWrapper.
This makes it possible to checkpoint arbitrary python state if it can be serialized to a string. Also updates NumpyState to accept np.int32, np.int64, np.float32, np.float64 types. PiperOrigin-RevId: 212879609
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py40
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py5
3 files changed, 35 insertions, 12 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 150d734db6..94b7f4f867 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -37,6 +37,7 @@ Checkpoint management:
Saving and restoring Python state:
@@NumpyState
+@@PythonStateWrapper
"""
from __future__ import absolute_import
@@ -45,6 +46,7 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.python_state import NumpyState
+from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035b6d..302d5cfb79 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import functools
+import six
import numpy
@@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase):
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(NumpyState, self).__getattribute__(name)
existing.array = value
@@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase):
super(NumpyState, self).__setattr__(name, value)
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+ """Wraps a Python object for storage in an object-based checkpoint."""
+
+ @abc.abstractmethod
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+ @abc.abstractmethod
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "py_state": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
+
+
+class _NumpyWrapper(PythonStateWrapper):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase):
self.array = array
def _serialize(self):
- """Callback for `PythonStringStateSaveable` to serialize the array."""
+ """Callback to serialize the array."""
string_file = BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase):
return serialized
def _deserialize(self, string_value):
- """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ """Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
- def _gather_saveables_for_checkpoint(self):
- """Specify callbacks for saving and restoring `array`."""
- return {
- "array": functools.partial(
- base.PythonStringStateSaveable,
- state_callback=self._serialize,
- restore_callback=self._deserialize)
- }
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
index 0439a4755e..45494351ff 100644
--- a/tensorflow/contrib/checkpoint/python/python_state_test.py
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase):
save_state.a = numpy.ones([2, 2])
save_state.b = numpy.ones([2, 2])
save_state.b = numpy.zeros([2, 2])
+ save_state.c = numpy.int64(3)
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ self.assertEqual(3, save_state.c)
first_save_path = saver.save(prefix)
save_state.a[1, 1] = 2.
+ save_state.c = numpy.int64(4)
second_save_path = saver.save(prefix)
load_state = python_state.NumpyState()
@@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(first_save_path).initialize_or_restore()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(3, load_state.c)
load_state.a[0, 0] = 42.
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
loader.restore(first_save_path).run_restore_ops()
@@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(second_save_path).run_restore_ops()
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(4, load_state.c)
def testNoGraphPollution(self):
graph = ops.Graph()