about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
author: Allen Lavoie <allenl@google.com> 2018-08-16 18:35:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-16 18:39:16 -0700
commit: 8bca3e4ed80a212bdcd8dc1c8505c4e92d2eac15 (patch)
tree: 50e3da0d2911d07331b9f13e8fc6c443d2eccf5e /tensorflow/contrib/checkpoint
parent: 4d5f6fb8b296bfbd7f72eabd9b7a9a8d29eab633 (diff)
tf.contrib.checkpoint.NumpyState for saving/restoring NumPy arrays with TF checkpoints

Adds a bit of extra infrastructure to checkpointable restore (save was already done) to support Python callbacks. The same strategy should work for any Python state, although it's confined to non-pickled NumPy arrays at the moment.

PiperOrigin-RevId: 209085928
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- tensorflow/contrib/checkpoint/__init__.py 4
-rw-r--r-- tensorflow/contrib/checkpoint/python/BUILD 28
-rw-r--r-- tensorflow/contrib/checkpoint/python/python_state.py 166
-rw-r--r-- tensorflow/contrib/checkpoint/python/python_state_test.py 101
4 files changed, 299 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index e92f0bb841..150d734db6 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -34,6 +34,9 @@ Checkpointable data structures:
Checkpoint management:
@@CheckpointManager
+
+Saving and restoring Python state:
+@@NumpyState
"""
from __future__ import absolute_import
@@ -41,6 +44,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
+from tensorflow.contrib.checkpoint.python.python_state import NumpyState
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index 7b200a29bf..ada4168726 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -9,6 +9,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":containers",
+ ":python_state",
":split_dependency",
":visualize",
"//tensorflow/python/training/checkpointable:data_structures",
@@ -41,6 +42,33 @@ py_test(
)
py_library(
+ name = "python_state",
+ srcs = ["python_state.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python/training/checkpointable:base",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "python_state_test",
+ srcs = ["python_state_test.py"],
+ deps = [
+ ":python_state",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:util",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
new file mode 100644
index 0000000000..9b11035b6d
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -0,0 +1,166 @@
+"""Utilities for including Python state in TensorFlow checkpoints."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy
+
+from tensorflow.python.training.checkpointable import base
+
+# pylint: disable=g-import-not-at-top
+try:
+ # In Python 2.x, use the faster string buffering option.
+ from cStringIO import StringIO as BytesIO
+except ImportError:
+ from io import BytesIO
+# pylint: enable=g-import-not-at-top
+
+
+class NumpyState(base.CheckpointableBase):
+  """A checkpointable object whose NumPy array attributes are saved/restored.
+
+  Example usage:
+
+  ```python
+  arrays = tf.contrib.checkpoint.NumpyState()
+  checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
+  arrays.x = numpy.zeros([3, 4])
+  save_path = checkpoint.save("/tmp/ckpt")
+  arrays.x[1, 1] = 4.
+  checkpoint.restore(save_path)
+  assert (arrays.x == numpy.zeros([3, 4])).all()
+
+  second_checkpoint = tf.train.Checkpoint(
+      numpy_arrays=tf.contrib.checkpoint.NumpyState())
+  # Attributes of NumpyState objects are created automatically by restore()
+  second_checkpoint.restore(save_path)
+  assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
+  ```
+
+  Note that `NumpyState` objects re-create the attributes of the previously
+  saved object on `restore()`. This is in contrast to TensorFlow variables, for
+  which a `Variable` object must be created and assigned to an attribute.
+
+  This snippet works both when graph building and when executing eagerly. On
+  save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
+  a placeholder when graph building, or as a string constant when executing
+  eagerly). When restoring they skip the TensorFlow graph entirely, and so no
+  restore ops need be run. This means that restoration always happens eagerly,
+  rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
+  TensorFlow variables when graph building.
+  """
+
+  def _lookup_dependency(self, name):
+    """Create placeholder NumPy arrays for to-be-restored attributes.
+
+    Typically `_lookup_dependency` is used to check by name whether a dependency
+    exists. We cheat slightly by creating a checkpointable object for `name` if
+    we don't already have one, giving us attribute re-creation behavior when
+    loading a checkpoint.
+
+    Args:
+      name: The name of the dependency being checked.
+    Returns:
+      An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
+      dependency (which will generally be restored immediately).
+    """
+    value = super(NumpyState, self)._lookup_dependency(name)
+    if value is None:
+      # Register an empty placeholder under `name`; the restore callback of
+      # the wrapper replaces its `array` with the checkpointed value.
+      value = _NumpyWrapper(numpy.array([]))
+      new_reference = base.CheckpointableReference(name=name, ref=value)
+      self._unconditional_checkpoint_dependencies.append(new_reference)
+      self._unconditional_dependency_names[name] = value
+      # Bypass our own `__setattr__`: `value` is a wrapper, not an ndarray.
+      super(NumpyState, self).__setattr__(name, value)
+    return value
+
+  def __getattribute__(self, name):
+    """Un-wrap `_NumpyWrapper` objects when accessing attributes.
+
+    Args:
+      name: The attribute name being read.
+    Returns:
+      The wrapped NumPy array if the attribute holds a `_NumpyWrapper`,
+      otherwise the attribute value unchanged.
+    """
+    value = super(NumpyState, self).__getattribute__(name)
+    if isinstance(value, _NumpyWrapper):
+      return value.array
+    return value
+
+  def __setattr__(self, name, value):
+    """Automatically wrap NumPy arrays assigned to attributes.
+
+    Args:
+      name: The attribute name being assigned.
+      value: The value to assign. A NumPy array is stored in the
+        `_NumpyWrapper` already held under `name`, or in a newly tracked one.
+
+    Raises:
+      NotImplementedError: If `value` is not a NumPy array, `name` is not one
+        of the bookkeeping attributes `_setattr_tracking`/`_update_uid`, and
+        `_setattr_tracking` is enabled.
+    """
+    # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
+    # ndarrays checkpointable natively and using standard checkpointable list
+    # tracking.
+    if isinstance(value, numpy.ndarray):
+      # Only the lookup is guarded: use the raw attribute (not the un-wrapped
+      # array returned by our `__getattribute__`).
+      try:
+        existing = super(NumpyState, self).__getattribute__(name)
+      except AttributeError:
+        existing = None
+      if isinstance(existing, _NumpyWrapper):
+        # Keep the already-tracked wrapper and just swap the array it holds.
+        existing.array = value
+        return
+      # No attribute yet, or one which is not a wrapper. Never attach the
+      # array to an unrelated object that merely accepts an `array` attribute;
+      # track a fresh wrapper under `name` instead.
+      value = _NumpyWrapper(value)
+      self._track_checkpointable(value, name=name, overwrite=True)
+    elif (name not in ("_setattr_tracking", "_update_uid")
+          and getattr(self, "_setattr_tracking", True)):
+      # Mixing restore()-created attributes with user-added checkpointable
+      # objects is tricky, since we can't use the `_lookup_dependency` trick to
+      # re-create attributes (we might accidentally steal the restoration for
+      # another checkpointable object). For now `NumpyState` objects must be
+      # leaf nodes. Theoretically we could add some extra arguments to
+      # `_lookup_dependency` to figure out whether we should create a NumPy
+      # array for the attribute or not.
+      raise NotImplementedError(
+          ("Assigned %s to the %s property of %s, which is not a NumPy array. "
+           "Currently mixing NumPy arrays and other checkpointable objects is "
+           "not supported. File a feature request if this limitation bothers "
+           "you.")
+          % (value, name, self))
+    super(NumpyState, self).__setattr__(name, value)
+
+
+class _NumpyWrapper(base.CheckpointableBase):
+  """Checkpointable holder for one NumPy array.
+
+  The array is exposed to object-based checkpoints as a string produced by
+  `numpy.save` and read back with `numpy.load`, with pickling disabled in
+  both directions.
+  """
+
+  def __init__(self, array):
+    """Wrap `array`.
+
+    Args:
+      array: The NumPy array to save and restore (may be overwritten).
+    """
+    self.array = array
+
+  def _serialize(self):
+    """Save callback: return `array` encoded in `numpy.save` format."""
+    buffer = BytesIO()
+    try:
+      numpy.save(buffer, self.array, allow_pickle=False)
+      return buffer.getvalue()
+    finally:
+      # Explicit close rather than `with`: the Python 2 `cStringIO` buffer
+      # imported as `BytesIO` is not a context manager.
+      buffer.close()
+
+  def _deserialize(self, string_value):
+    """Restore callback: replace `array` with the one in `string_value`."""
+    buffer = BytesIO(string_value)
+    try:
+      restored = numpy.load(buffer, allow_pickle=False)
+    finally:
+      buffer.close()
+    self.array = restored
+
+  def _gather_saveables_for_checkpoint(self):
+    """Map the saveable name "array" to a `PythonStringStateSaveable` factory."""
+    saveable_factory = functools.partial(
+        base.PythonStringStateSaveable,
+        state_callback=self._serialize,
+        restore_callback=self._deserialize)
+    return {"array": saveable_factory}
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
new file mode 100644
index 0000000000..0439a4755e
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy
+
+from tensorflow.contrib.checkpoint.python import python_state
+from tensorflow.python.client import session
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.training.checkpointable import util
+
+
+class NumpyStateTests(test.TestCase):
+ # Exercises python_state.NumpyState through util.Checkpoint save()/restore().
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreNumpyState(self):
+ # Round-trips two arrays through two checkpoints into a fresh NumpyState.
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_state.b = numpy.ones([2, 2])
+ # Re-assigning an existing attribute must replace the earlier array.
+ save_state.b = numpy.zeros([2, 2])
+ self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ first_save_path = saver.save(prefix)
+ # Mutate `a` in place between the saves so the two checkpoints differ.
+ save_state.a[1, 1] = 2.
+ second_save_path = saver.save(prefix)
+
+ # `load_state` never has `a`/`b` assigned here; both are expected to be
+ # created by restoring the first checkpoint.
+ load_state = python_state.NumpyState()
+ loader = util.Checkpoint(numpy=load_state)
+ loader.restore(first_save_path).initialize_or_restore()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ # An in-place edit of a restored array is visible through the attribute...
+ load_state.a[0, 0] = 42.
+ self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
+ # ...and is discarded again by restoring the same checkpoint.
+ loader.restore(first_save_path).run_restore_ops()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ # The second checkpoint carries the `a[1, 1] = 2.` edit; `b` is unchanged.
+ loader.restore(second_save_path).run_restore_ops()
+ self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+
+ def testNoGraphPollution(self):
+ # After one save/restore, later saves and restores must not add graph ops.
+ graph = ops.Graph()
+ with graph.as_default(), session.Session():
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_path = saver.save(prefix)
+ saver.restore(save_path)
+ # finalize() makes the graph read-only, so any op creation below raises.
+ graph.finalize()
+ saver.save(prefix)
+ # Assigning a new array to an existing attribute must not need new ops.
+ save_state.a = numpy.zeros([2, 2])
+ saver.save(prefix)
+ saver.restore(save_path)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoMixedNumpyStateTF(self):
+ # Assigning a non-NumPy value (here a Variable) is rejected.
+ save_state = python_state.NumpyState()
+ save_state.a = numpy.ones([2, 2])
+ with self.assertRaises(NotImplementedError):
+ save_state.v = variables.Variable(1.)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDocstringExample(self):
+ # Mirrors the usage example in the NumpyState class docstring.
+ arrays = python_state.NumpyState()
+ checkpoint = util.Checkpoint(numpy_arrays=arrays)
+ arrays.x = numpy.zeros([3, 4])
+ save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ arrays.x[1, 1] = 4.
+ # No run_restore_ops() call here: the array is checked right after
+ # restore() returns.
+ checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
+
+ # A NumpyState that never had `x` assigned gets it from restore().
+ second_checkpoint = util.Checkpoint(numpy_arrays=python_state.NumpyState())
+ second_checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
+
+
+if __name__ == "__main__":
+ test.main()