diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-13 14:31:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 14:35:37 -0700 |
commit | 4999d856d2953aee56fa9759f995038edf3ff566 (patch) | |
tree | 00ad48d221451352c361534f18da5ad50f7101b7 /tensorflow/contrib/checkpoint | |
parent | d3458112ad5a1612ec6c77f7de4a0e0ec801e882 (diff) |
Expose tf.contrib.checkpoint.PythonStateWrapper.
This makes it possible to checkpoint arbitrary python state if it can be
serialized to a string.
Also updates NumpyState to accept np.int32, np.int64, np.float32, np.float64
types.
PiperOrigin-RevId: 212879609
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state.py | 40 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state_test.py | 5 |
3 files changed, 35 insertions, 12 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 150d734db6..94b7f4f867 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -37,6 +37,7 @@ Checkpoint management: Saving and restoring Python state: @@NumpyState +@@PythonStateWrapper """ from __future__ import absolute_import @@ -45,6 +46,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.python_state import NumpyState +from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 9b11035b6d..302d5cfb79 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import functools +import six import numpy @@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase): # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making # ndarrays checkpointable natively and using standard checkpointable list # tracking. - if isinstance(value, numpy.ndarray): + if isinstance(value, (numpy.ndarray, numpy.generic)): try: existing = super(NumpyState, self).__getattribute__(name) existing.array = value @@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase): super(NumpyState, self).__setattr__(name, value) -class _NumpyWrapper(base.CheckpointableBase): +@six.add_metaclass(abc.ABCMeta) +class PythonStateWrapper(base.CheckpointableBase): + """Wraps a Python object for storage in an object-based checkpoint.""" + + @abc.abstractmethod + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the object.""" + + @abc.abstractmethod + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the object.""" + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "py_state": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } + + +class _NumpyWrapper(PythonStateWrapper): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase): self.array = array def _serialize(self): - """Callback for `PythonStringStateSaveable` to serialize the array.""" + """Callback to serialize the array.""" string_file = BytesIO() try: numpy.save(string_file, self.array, allow_pickle=False) @@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase): return serialized def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the array.""" + """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "array": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 0439a4755e..45494351ff 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase): save_state.a = numpy.ones([2, 2]) save_state.b = numpy.ones([2, 2]) save_state.b = numpy.zeros([2, 2]) + save_state.c = numpy.int64(3) self.assertAllEqual(numpy.ones([2, 2]), save_state.a) self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + self.assertEqual(3, save_state.c) first_save_path = saver.save(prefix) save_state.a[1, 1] = 2. + save_state.c = numpy.int64(4) second_save_path = saver.save(prefix) load_state = python_state.NumpyState() @@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase): loader.restore(first_save_path).initialize_or_restore() self.assertAllEqual(numpy.ones([2, 2]), load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(3, load_state.c) load_state.a[0, 0] = 42. self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) loader.restore(first_save_path).run_restore_ops() @@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase): loader.restore(second_save_path).run_restore_ops() self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(4, load_state.c) def testNoGraphPollution(self): graph = ops.Graph() |