aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-19 09:54:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 09:57:23 -0700
commit436f1434060d7f370baae9661baacc6cf27415ec (patch)
tree9276bf419abb2e1bcfccdb7c43bea7ffbab80283 /tensorflow/contrib/checkpoint
parentc173157bdc132460c6f424a9803221e74fc73f59 (diff)
Create a skeleton tf.contrib.checkpoint.
My plan for this is to incubate tools for working with object-based checkpoints:
 - Tools for managing dependency graphs, e.g. checkpointable lists/dictionaries
 - Inspecting/visualizing checkpoints
 - Listing variables and gathering initializers from a Checkpointable object and its dependencies
 - Verifying all variables are accessible as dependencies, which should make converting existing graph-building Saver uses easier/safer.
This CL includes none of those things; it just moves the split_dependency tool here from contrib/eager.
PiperOrigin-RevId: 193531292
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/README.md2
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py29
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD29
-rw-r--r--tensorflow/contrib/checkpoint/python/split_dependency.py136
-rw-r--r--tensorflow/contrib/checkpoint/python/split_dependency_test.py112
5 files changed, 308 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/README.md b/tensorflow/contrib/checkpoint/README.md
new file mode 100644
index 0000000000..d35c5bae3b
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/README.md
@@ -0,0 +1,2 @@
+Tools for working with object-based checkpoints produced by
+`tf.train.Checkpoint`.
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
new file mode 100644
index 0000000000..70d7d2d8d7
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tools for working with object-based checkpoints.
+
+
+For creating and managing dependencies:
+@@split_dependency
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
+from tensorflow.python.util.all_util import remove_undocumented
+
+remove_undocumented(module_name=__name__)
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
new file mode 100644
index 0000000000..d57b01aab2
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -0,0 +1,29 @@
+# Build rules for the Python sources of tf.contrib.checkpoint.
+licenses(["notice"]) # Apache 2.0
+
+# Targets in this package are visible to //tensorflow:internal by default.
+package(default_visibility = ["//tensorflow:internal"])
+
+# `py_test` below is the macro defined in tensorflow.bzl.
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+# Library target wrapping split_dependency.py.
+py_library(
+ name = "split_dependency",
+ srcs = ["split_dependency.py"],
+ srcs_version = "PY2AND3",
+ # Same value as the package-level default_visibility above.
+ visibility = ["//tensorflow:internal"],
+ # split_dependency.py imports control_flow_ops plus the `checkpointable` and
+ # `saver` modules from tensorflow.python.training (presumably provided by
+ # the :training target -- TODO confirm).
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:training",
+ ],
+)
+
+# Test target for split_dependency_test.py; depends on the library above.
+py_test(
+ name = "split_dependency_test",
+ srcs = ["split_dependency_test.py"],
+ deps = [
+ ":split_dependency",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ ],
+)
diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py
new file mode 100644
index 0000000000..3aec8c96e9
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/split_dependency.py
@@ -0,0 +1,136 @@
+"""Utility for creating multiple dependencies with synchronized save/restore."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.training import checkpointable as checkpointable
+from tensorflow.python.training import saver as saver_lib
+
+
+class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """Wraps save and restore callbacks as a `SaveableObject`."""
+
+ def __init__(self, name, dtype, save_callback, restore_callback):
+ self._restore_callback = restore_callback
+ spec = saver_lib.BaseSaverBuilder.SaveSpec(
+ tensor=save_callback,
+ slice_spec="",
+ name=name,
+ dtype=dtype)
+ super(_CallbackSaveable, self).__init__(
+ save_callback, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return self._restore_callback(tensor)
+
+
+class _SplitDependency(checkpointable.CheckpointableBase):
+ """Looks like a regular variable while synchronizing save/restores."""
+
+ def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
+ fill_save_buffer_fn, consume_restore_buffer_fn):
+ self._save_buffer = save_buffer
+ self._restore_buffer = restore_buffer
+ self._name = name
+ self._dtype = dtype
+ self._num_components = num_components
+ self._fill_save_buffer_fn = fill_save_buffer_fn
+ self._consume_restore_buffer_fn = consume_restore_buffer_fn
+
+ def _save(self):
+ """Pull from the shared buffer, populating it if necessary."""
+ if self._name not in self._save_buffer:
+ if self._save_buffer:
+ raise AssertionError(
+ ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+ "be saved together.") % (self._name, self))
+ self._fill_save_buffer_fn(self._save_buffer)
+ return self._save_buffer.pop(self._name)
+
+ def _restore(self, tensor):
+ """Push into the shared buffer, flushing it if necessary."""
+ if self._name in self._restore_buffer:
+ raise AssertionError(
+ ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+ "be restored together.") % (self._name, self))
+ self._restore_buffer[self._name] = tensor
+ if len(self._restore_buffer) == self._num_components:
+ op = self._consume_restore_buffer_fn(self._restore_buffer)
+ self._restore_buffer.clear()
+ return op
+ else:
+ return control_flow_ops.no_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Looks to Checkpointable like a regular variable."""
+ return {
+ checkpointable.VARIABLE_VALUE_KEY:
+ functools.partial(_CallbackSaveable,
+ dtype=self._dtype,
+ save_callback=self._save,
+ restore_callback=self._restore)
+ }
+
+
+def split_dependency(component_names, component_dtypes,
+ fill_save_buffer_fn, consume_restore_buffer_fn):
+ """Creates multiple dependencies with a synchronized save/restore.
+
+ Useful when a single op produces `Tensor`s which should each be saved under
+ different objects, or when `Tensor`s saved with many different objects need to
+ be restored together as inputs to a single op (i.e. an object which uses a
+ single fused op may be swapped out for a subgraph of objects, and these two
+ programs are checkpoint compatible).
+
+ Args:
+ component_names: A sequence of names for the split
+ dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
+ it is passed, and `consume_restore_buffer_fn` will receive a dictionary
+ with these keys.
+ component_dtypes: Data types for the `Tensor`s being saved and restored, a
+ sequence corresponding to `component_names`.
+ fill_save_buffer_fn: A function which takes an empty dictionary as an
+ argument and adds `Tensor`s with `component_names` as keys. These
+ `Tensor`s will be saved as if they were individual variables.
+ consume_restore_buffer_fn: A function which takes a dictionary with
+ `component_names` as keys mapping to restored individual `Tensor`s and
+ returns a restore op (or if executing eagerly, runs the restoration and
+ may return `None`).
+
+ Returns:
+ A dictionary mapping from names to Checkpointable objects. If one is
+ reachable from an object as a dependency, the others should be too; adding
+ dependencies on some but not all of the objects will result in errors.
+ """
+ save_buffer = {}
+ restore_buffer = {}
+ split_dependencies = {}
+ for name, dtype in zip(component_names, component_dtypes):
+ split_dependencies[name] = _SplitDependency(
+ save_buffer=save_buffer,
+ restore_buffer=restore_buffer,
+ name=name,
+ dtype=dtype,
+ num_components=len(component_names),
+ fill_save_buffer_fn=fill_save_buffer_fn,
+ consume_restore_buffer_fn=consume_restore_buffer_fn)
+ return split_dependencies
diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
new file mode 100644
index 0000000000..cb964c80e9
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.checkpoint.python import split_dependency
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
+
+
+def _split_variable_closure(variable):
+ def _fill_save_buffer_fn(save_buffer):
+ save_buffer["first_half"] = variable[:2]
+ save_buffer["second_half"] = variable[2:]
+ return _fill_save_buffer_fn
+
+
+def _combine_variable_closure(variable):
+ def _consume_restore_buffer_fn(restore_buffer):
+ return variable.assign(
+ array_ops.concat([restore_buffer["first_half"],
+ restore_buffer["second_half"]],
+ axis=0))
+ return _consume_restore_buffer_fn
+
+
+class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+
+ def __init__(self):
+ self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
+ split_dependencies = split_dependency.split_dependency(
+ component_names=("first_half", "second_half"),
+ component_dtypes=(self.combined.dtype,) * 2,
+ fill_save_buffer_fn=_split_variable_closure(
+ self.combined),
+ consume_restore_buffer_fn=_combine_variable_closure(
+ self.combined))
+ for name, dep in split_dependencies.items():
+ self._track_checkpointable(dep, name=name)
+
+
+class HasRegularDeps(checkpointable.Checkpointable):
+ """Holds `first_half` and `second_half` as two ordinary two-element variables.
+
+ Its attribute names match the component names used by
+ `SaveTensorSlicesAsDeps`; the test restores a checkpoint written by that
+ class into an instance of this one.
+ """
+
+ def __init__(self):
+ self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+ self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class OnlyOneDep(checkpointable.Checkpointable):
+ """Holds only a `first_half` variable, with no `second_half` counterpart.
+
+ Used by the test to restore a checkpoint in which one saved component has no
+ matching attribute.
+ """
+
+ def __init__(self):
+ self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class SplitTests(test.TestCase):
+ """Save/restore round trips between split and regular dependencies."""
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testSaveRestoreSplitDep(self):
+ """Saves a split variable, then restores it into three object layouts."""
+ # Save [1, 2, 3, 4] from an object that exposes it as two split halves.
+ save_checkpoint = checkpointable_utils.Checkpoint(
+ dep=SaveTensorSlicesAsDeps())
+ self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_checkpoint.save(checkpoint_prefix)
+
+ # Restore into two ordinary variables: each receives its own half.
+ regular_deps = HasRegularDeps()
+ regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+ dep=regular_deps)
+ regular_restore_checkpoint.restore(
+ save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half))
+ self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
+
+ # Restore into an object with only one of the halves: the checkpoint is
+ # not fully consumed, but the half that is present is still restored.
+ one_dep = OnlyOneDep()
+ one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+ status = one_dep_restore_checkpoint.restore(save_path)
+ with self.assertRaises(AssertionError):
+ # Missing the second dependency.
+ status.assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
+
+ # Call restore() before the split-dependency object is attached; the
+ # values are expected to arrive once `dep` is assigned afterwards.
+ restore_checkpoint = checkpointable_utils.Checkpoint()
+ status = restore_checkpoint.restore(save_path)
+ restore_checkpoint.dep = SaveTensorSlicesAsDeps()
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual(
+ [1., 2., 3., 4.],
+ self.evaluate(restore_checkpoint.dep.combined))
+
+
+# Only invoke the test runner when this file is executed as a script.
+if __name__ == "__main__":
+ test.main()