diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-11 14:45:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 14:48:25 -0700 |
commit | a8dbfc607ecaa54b032573a0c7033cb4c9d033a2 (patch) | |
tree | dc5970056d4a8e32d1d7732d0ff7b03c28f707fe /tensorflow/contrib/checkpoint | |
parent | ddb8fe491faccfdf219a5d9b7ba959c98ae38f33 (diff) |
Checkpointable: Add UniqueNameTracker for managing dependencies on arbitrarily named objects
Makes generating object-unique dependency names easier, which will hopefully discourage people from using Graph-global names with Checkpointable.
PiperOrigin-RevId: 196311633
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/containers.py | 77 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/containers_test.py | 100 |
4 files changed, 208 insertions, 3 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index e529b25b3c..c5f7072aea 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -14,22 +14,27 @@ # ============================================================================== """Tools for working with object-based checkpoints. - -For creating and managing dependencies: -@@CheckpointableObjectGraph +Visualization and inspection: @@dot_graph_from_checkpoint @@object_metadata + +Creating and managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph @@NoDependency @@split_dependency +@@UniqueNameTracker """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable import Checkpointable from tensorflow.python.training.checkpointable import NoDependency from tensorflow.python.training.checkpointable_utils import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index a5681ffa61..cbb9852ccf 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -8,12 +8,35 @@ py_library( name = "checkpoint", srcs_version = "PY2AND3", deps = [ + ":containers", ":split_dependency", ":visualize", ], ) py_library( + name = "containers", + srcs = ["containers.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python:checkpointable"], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:checkpointable", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "@six_archive//:six", + ], +) + +py_library( name = "split_dependency", srcs = ["split_dependency.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 0000000000..82aa04e38f --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,77 @@ +"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training import checkpointable as checkpointable_lib + + +class UniqueNameTracker(checkpointable_lib.CheckpointableBase): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. + self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. + """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + return self._track_checkpointable(checkpointable, name=candidate) diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 0000000000..15775f4cb3 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,100 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable_utils import object_metadata + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) + slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) + status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + +if __name__ == "__main__": + test.main() |