diff options
author | Allen Lavoie <allenl@google.com> | 2018-04-19 09:54:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 09:57:23 -0700 |
commit | 436f1434060d7f370baae9661baacc6cf27415ec (patch) | |
tree | 9276bf419abb2e1bcfccdb7c43bea7ffbab80283 /tensorflow/contrib/checkpoint | |
parent | c173157bdc132460c6f424a9803221e74fc73f59 (diff) |
Create a skeleton tf.contrib.checkpoint.
My plan for this is to incubate tools for working with object-based checkpoints:
- Tools for managing dependency graphs, e.g. checkpointable lists/dictionaries
- Inspecting/visualizing checkpoints
- Listing variables and gathering initializers from a Checkpointable object
and its dependencies
- Verifying all variables are accessible as dependencies, which should make
converting existing graph building Saver uses easier/safer.
This CL includes none of those things, it just moves the split_dependency tool
here instead of contrib/eager.
PiperOrigin-RevId: 193531292
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/README.md | 2 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/split_dependency.py | 136 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/split_dependency_test.py | 112 |
5 files changed, 308 insertions, 0 deletions
diff --git a/tensorflow/contrib/checkpoint/README.md b/tensorflow/contrib/checkpoint/README.md new file mode 100644 index 0000000000..d35c5bae3b --- /dev/null +++ b/tensorflow/contrib/checkpoint/README.md @@ -0,0 +1,2 @@ +Tools for working with object-based checkpoints produced by +`tf.train.Checkpoint`. diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py new file mode 100644 index 0000000000..70d7d2d8d7 --- /dev/null +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tools for working with object-based checkpoints. + + +For creating and managing dependencies: +@@split_dependency +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD new file mode 100644 index 0000000000..d57b01aab2 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -0,0 +1,29 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "split_dependency", + srcs = ["split_dependency.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "split_dependency_test", + srcs = ["split_dependency_test.py"], + deps = [ + ":split_dependency", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py new file mode 100644 index 0000000000..3aec8c96e9 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -0,0 +1,136 @@ +"""Utility for creating multiple dependencies with synchronized save/restore.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import checkpointable as checkpointable +from tensorflow.python.training import saver as saver_lib + + +class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + """Wraps save and restore callbacks as a `SaveableObject`.""" + + def __init__(self, name, dtype, save_callback, restore_callback): + self._restore_callback = restore_callback + spec = saver_lib.BaseSaverBuilder.SaveSpec( + tensor=save_callback, + slice_spec="", + name=name, + dtype=dtype) + super(_CallbackSaveable, self).__init__( + save_callback, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return self._restore_callback(tensor) + + +class _SplitDependency(checkpointable.CheckpointableBase): + """Looks like a regular variable while synchronizing save/restores.""" + + def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, + fill_save_buffer_fn, consume_restore_buffer_fn): + self._save_buffer = save_buffer + self._restore_buffer = restore_buffer + self._name = name + self._dtype = dtype + self._num_components = num_components + self._fill_save_buffer_fn = fill_save_buffer_fn + self._consume_restore_buffer_fn = consume_restore_buffer_fn + + def _save(self): + """Pull from the shared buffer, populating it if necessary.""" + if self._name not in self._save_buffer: + if self._save_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be saved together.") % (self._name, self)) + self._fill_save_buffer_fn(self._save_buffer) + return self._save_buffer.pop(self._name) + + def _restore(self, tensor): + """Push into the shared buffer, flushing it if necessary.""" + if self._name in self._restore_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be restored together.") % (self._name, self)) + self._restore_buffer[self._name] = tensor + if len(self._restore_buffer) == self._num_components: + op = self._consume_restore_buffer_fn(self._restore_buffer) + self._restore_buffer.clear() + return op + else: + return control_flow_ops.no_op() + + def _gather_saveables_for_checkpoint(self): + """Looks to Checkpointable like a regular variable.""" + return { + checkpointable.VARIABLE_VALUE_KEY: + functools.partial(_CallbackSaveable, + dtype=self._dtype, + save_callback=self._save, + restore_callback=self._restore) + } + + +def split_dependency(component_names, component_dtypes, + fill_save_buffer_fn, consume_restore_buffer_fn): + """Creates multiple dependencies with a synchronized save/restore. + + Useful when a single op produces `Tensor`s which should each be saved under + different objects, or when `Tensor`s saved with many different objects need to + be restored together as inputs to a single op (i.e. an object which uses a + single fused op may be swapped out for a subgraph of objects, and these two + programs are checkpoint compatible). + + Args: + component_names: A sequence of names for the split + dependencies. `fill_save_buffer_fn` must add these keys to the dictionary + it is passed, and `consume_restore_buffer_fn` will receive a dictionary + with these keys. + component_dtypes: Data types for the `Tensor`s being saved and restored, a + sequence corresponding to `component_names`. + fill_save_buffer_fn: A function which takes an empty dictionary as an + argument and adds `Tensor`s with `component_names` as keys. These + `Tensor`s will be saved as if they were individual variables. + consume_restore_buffer_fn: A function which takes a dictionary with + `component_names` as keys mapping to restored individual `Tensor`s and + returns a restore op (or if executing eagerly, runs the restoration and + may return `None`). + + Returns: + A dictionary mapping from names to Checkpointable objects. If one is + reachable from an object as a dependency, the others should be too; adding + dependencies on some but not all of the objects will result in errors. + """ + save_buffer = {} + restore_buffer = {} + split_dependencies = {} + for name, dtype in zip(component_names, component_dtypes): + split_dependencies[name] = _SplitDependency( + save_buffer=save_buffer, + restore_buffer=restore_buffer, + name=name, + dtype=dtype, + num_components=len(component_names), + fill_save_buffer_fn=fill_save_buffer_fn, + consume_restore_buffer_fn=consume_restore_buffer_fn) + return split_dependencies diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py new file mode 100644 index 0000000000..cb964c80e9 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.checkpoint.python import split_dependency +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils + + +def _split_variable_closure(variable): + def _fill_save_buffer_fn(save_buffer): + save_buffer["first_half"] = variable[:2] + save_buffer["second_half"] = variable[2:] + return _fill_save_buffer_fn + + +def _combine_variable_closure(variable): + def _consume_restore_buffer_fn(restore_buffer): + return variable.assign( + array_ops.concat([restore_buffer["first_half"], + restore_buffer["second_half"]], + axis=0)) + return _consume_restore_buffer_fn + + +class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): + + def __init__(self): + self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) + split_dependencies = split_dependency.split_dependency( + component_names=("first_half", "second_half"), + component_dtypes=(self.combined.dtype,) * 2, + fill_save_buffer_fn=_split_variable_closure( + self.combined), + consume_restore_buffer_fn=_combine_variable_closure( + self.combined)) + for name, dep in split_dependencies.items(): + self._track_checkpointable(dep, name=name) + + +class HasRegularDeps(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class OnlyOneDep(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class SplitTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testSaveRestoreSplitDep(self): + save_checkpoint = checkpointable_utils.Checkpoint( + dep=SaveTensorSlicesAsDeps()) + self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_checkpoint.save(checkpoint_prefix) + + regular_deps = HasRegularDeps() + regular_restore_checkpoint = checkpointable_utils.Checkpoint( + dep=regular_deps) + regular_restore_checkpoint.restore( + save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) + self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) + + one_dep = OnlyOneDep() + one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + status = one_dep_restore_checkpoint.restore(save_path) + with self.assertRaises(AssertionError): + # Missing the second dependency. + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) + + restore_checkpoint = checkpointable_utils.Checkpoint() + status = restore_checkpoint.restore(save_path) + restore_checkpoint.dep = SaveTensorSlicesAsDeps() + status.assert_consumed().run_restore_ops() + self.assertAllEqual( + [1., 2., 3., 4.], + self.evaluate(restore_checkpoint.dep.combined)) + + +if __name__ == "__main__": + test.main() |