diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-16 18:35:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 18:39:16 -0700 |
commit | 8bca3e4ed80a212bdcd8dc1c8505c4e92d2eac15 (patch) | |
tree | 50e3da0d2911d07331b9f13e8fc6c443d2eccf5e /tensorflow/contrib/checkpoint | |
parent | 4d5f6fb8b296bfbd7f72eabd9b7a9a8d29eab633 (diff) |
tf.contrib.checkpoint.NumpyState for saving/restoring NumPy arrays with TF checkpoints
A bit of extra infrastructure in checkpointable restore (save was already done) to support Python callbacks.
The same strategy should work for any Python state, although it's confined to non-pickled NumPy arrays at the moment.
PiperOrigin-RevId: 209085928
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/BUILD | 28 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state.py | 166 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/python_state_test.py | 101 |
4 files changed, 299 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index e92f0bb841..150d734db6 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -34,6 +34,9 @@ Checkpointable data structures: Checkpoint management: @@CheckpointManager + +Saving and restoring Python state: +@@NumpyState """ from __future__ import absolute_import @@ -41,6 +44,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker +from tensorflow.contrib.checkpoint.python.python_state import NumpyState from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 7b200a29bf..ada4168726 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -9,6 +9,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":containers", + ":python_state", ":split_dependency", ":visualize", "//tensorflow/python/training/checkpointable:data_structures", @@ -41,6 +42,33 @@ py_test( ) py_library( + name = "python_state", + srcs = ["python_state.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "python_state_test", + srcs = ["python_state_test.py"], + deps = [ + ":python_state", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:util", + "//third_party/py/numpy", + ], +) + +py_library( name = "split_dependency", srcs = ["split_dependency.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py new file mode 100644 index 0000000000..9b11035b6d --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -0,0 +1,166 @@ +"""Utilities for including Python state in TensorFlow checkpoints.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy + +from tensorflow.python.training.checkpointable import base + +# pylint: disable=g-import-not-at-top +try: + # In Python 2.x, use the faster string buffering option. + from cStringIO import StringIO as BytesIO +except ImportError: + from io import BytesIO +# pylint: enable=g-import-not-at-top + + +class NumpyState(base.CheckpointableBase): + """A checkpointable object whose NumPy array attributes are saved/restored. + + Example usage: + + ```python + arrays = tf.contrib.checkpoint.NumpyState() + checkpoint = tf.train.Checkpoint(numpy_arrays=arrays) + arrays.x = numpy.zeros([3, 4]) + save_path = checkpoint.save("/tmp/ckpt") + arrays.x[1, 1] = 4. + checkpoint.restore(save_path) + assert (arrays.x == numpy.zeros([3, 4])).all() + + second_checkpoint = tf.train.Checkpoint( + numpy_arrays=tf.contrib.checkpoint.NumpyState()) + # Attributes of NumpyState objects are created automatically by restore() + second_checkpoint.restore(save_path) + assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all() + ``` + + Note that `NumpyState` objects re-create the attributes of the previously + saved object on `restore()`. This is in contrast to TensorFlow variables, for + which a `Variable` object must be created and assigned to an attribute. + + This snippet works both when graph building and when executing eagerly. On + save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via + a placeholder when graph building, or as a string constant when executing + eagerly). When restoring they skip the TensorFlow graph entirely, and so no + restore ops need be run. This means that restoration always happens eagerly, + rather than waiting for `checkpoint.restore(...).run_restore_ops()` like + TensorFlow variables when graph building. + """ + + def _lookup_dependency(self, name): + """Create placeholder NumPy arrays for to-be-restored attributes. + + Typically `_lookup_dependency` is used to check by name whether a dependency + exists. We cheat slightly by creating a checkpointable object for `name` if + we don't already have one, giving us attribute re-creation behavior when + loading a checkpoint. + + Args: + name: The name of the dependency being checked. + Returns: + An existing dependency if one exists, or a new `_NumpyWrapper` placeholder + dependency (which will generally be restored immediately). + """ + value = super(NumpyState, self)._lookup_dependency(name) + if value is None: + value = _NumpyWrapper(numpy.array([])) + new_reference = base.CheckpointableReference(name=name, ref=value) + self._unconditional_checkpoint_dependencies.append(new_reference) + self._unconditional_dependency_names[name] = value + super(NumpyState, self).__setattr__(name, value) + return value + + def __getattribute__(self, name): + """Un-wrap `_NumpyWrapper` objects when accessing attributes.""" + value = super(NumpyState, self).__getattribute__(name) + if isinstance(value, _NumpyWrapper): + return value.array + return value + + def __setattr__(self, name, value): + """Automatically wrap NumPy arrays assigned to attributes.""" + # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making + # ndarrays checkpointable natively and using standard checkpointable list + # tracking. + if isinstance(value, numpy.ndarray): + try: + existing = super(NumpyState, self).__getattribute__(name) + existing.array = value + return + except AttributeError: + value = _NumpyWrapper(value) + self._track_checkpointable(value, name=name, overwrite=True) + elif (name not in ("_setattr_tracking", "_update_uid") + and getattr(self, "_setattr_tracking", True)): + # Mixing restore()-created attributes with user-added checkpointable + # objects is tricky, since we can't use the `_lookup_dependency` trick to + # re-create attributes (we might accidentally steal the restoration for + # another checkpointable object). For now `NumpyState` objects must be + # leaf nodes. Theoretically we could add some extra arguments to + # `_lookup_dependency` to figure out whether we should create a NumPy + # array for the attribute or not. + raise NotImplementedError( + ("Assigned %s to the %s property of %s, which is not a NumPy array. " + "Currently mixing NumPy arrays and other checkpointable objects is " + "not supported. File a feature request if this limitation bothers " + "you.") + % (value, name, self)) + super(NumpyState, self).__setattr__(name, value) + + +class _NumpyWrapper(base.CheckpointableBase): + """Wraps a NumPy array for storage in an object-based checkpoint.""" + + def __init__(self, array): + """Specify a NumPy array to wrap. + + Args: + array: The NumPy array to save and restore (may be overwritten). + """ + self.array = array + + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the array.""" + string_file = BytesIO() + try: + numpy.save(string_file, self.array, allow_pickle=False) + serialized = string_file.getvalue() + finally: + string_file.close() + return serialized + + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the array.""" + string_file = BytesIO(string_value) + try: + self.array = numpy.load(string_file, allow_pickle=False) + finally: + string_file.close() + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "array": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py new file mode 100644 index 0000000000..0439a4755e --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy + +from tensorflow.contrib.checkpoint.python import python_state +from tensorflow.python.client import session +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.training.checkpointable import util + + +class NumpyStateTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreNumpyState(self): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + save_state = python_state.NumpyState() + saver = util.Checkpoint(numpy=save_state) + save_state.a = numpy.ones([2, 2]) + save_state.b = numpy.ones([2, 2]) + save_state.b = numpy.zeros([2, 2]) + self.assertAllEqual(numpy.ones([2, 2]), save_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + first_save_path = saver.save(prefix) + save_state.a[1, 1] = 2. + second_save_path = saver.save(prefix) + + load_state = python_state.NumpyState() + loader = util.Checkpoint(numpy=load_state) + loader.restore(first_save_path).initialize_or_restore() + self.assertAllEqual(numpy.ones([2, 2]), load_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + load_state.a[0, 0] = 42. + self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) + loader.restore(first_save_path).run_restore_ops() + self.assertAllEqual(numpy.ones([2, 2]), load_state.a) + loader.restore(second_save_path).run_restore_ops() + self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) + self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + + def testNoGraphPollution(self): + graph = ops.Graph() + with graph.as_default(), session.Session(): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + save_state = python_state.NumpyState() + saver = util.Checkpoint(numpy=save_state) + save_state.a = numpy.ones([2, 2]) + save_path = saver.save(prefix) + saver.restore(save_path) + graph.finalize() + saver.save(prefix) + save_state.a = numpy.zeros([2, 2]) + saver.save(prefix) + saver.restore(save_path) + + @test_util.run_in_graph_and_eager_modes + def testNoMixedNumpyStateTF(self): + save_state = python_state.NumpyState() + save_state.a = numpy.ones([2, 2]) + with self.assertRaises(NotImplementedError): + save_state.v = variables.Variable(1.) + + @test_util.run_in_graph_and_eager_modes + def testDocstringExample(self): + arrays = python_state.NumpyState() + checkpoint = util.Checkpoint(numpy_arrays=arrays) + arrays.x = numpy.zeros([3, 4]) + save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + arrays.x[1, 1] = 4. + checkpoint.restore(save_path) + self.assertAllEqual(numpy.zeros([3, 4]), arrays.x) + + second_checkpoint = util.Checkpoint(numpy_arrays=python_state.NumpyState()) + second_checkpoint.restore(save_path) + self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x) + + +if __name__ == "__main__": + test.main() |