aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-11 14:45:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 14:48:25 -0700
commita8dbfc607ecaa54b032573a0c7033cb4c9d033a2 (patch)
treedc5970056d4a8e32d1d7732d0ff7b03c28f707fe /tensorflow/contrib/checkpoint
parentddb8fe491faccfdf219a5d9b7ba959c98ae38f33 (diff)
Checkpointable: Add UniqueNameTracker for managing dependencies on arbitrarily named objects
Makes generating object-unique dependency names easier, which will hopefully discourage people from using Graph-global names with Checkpointable. PiperOrigin-RevId: 196311633
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py11
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD23
-rw-r--r--tensorflow/contrib/checkpoint/python/containers.py77
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py100
4 files changed, 208 insertions, 3 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index e529b25b3c..c5f7072aea 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -14,22 +14,27 @@
# ==============================================================================
"""Tools for working with object-based checkpoints.
-
-For creating and managing dependencies:
-@@CheckpointableObjectGraph
+Visualization and inspection:
@@dot_graph_from_checkpoint
@@object_metadata
+
+Creating and managing dependencies:
+@@Checkpointable
+@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
+@@UniqueNameTracker
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpointable import Checkpointable
from tensorflow.python.training.checkpointable import NoDependency
from tensorflow.python.training.checkpointable_utils import object_metadata
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index a5681ffa61..cbb9852ccf 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -8,12 +8,35 @@ py_library(
name = "checkpoint",
srcs_version = "PY2AND3",
deps = [
+ ":containers",
":split_dependency",
":visualize",
],
)
py_library(
+ name = "containers",
+ srcs = ["containers.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = ["//tensorflow/python:checkpointable"],
+)
+
+py_test(
+ name = "containers_test",
+ srcs = ["containers_test.py"],
+ deps = [
+ ":containers",
+ "//tensorflow/python:checkpointable",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
new file mode 100644
index 0000000000..82aa04e38f
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -0,0 +1,77 @@
+"""Checkpointable data structures."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.training import checkpointable as checkpointable_lib
+
+
+class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
+ """Adds dependencies on checkpointable objects with name hints.
+
+ Useful for creating dependencies with locally unique names.
+
+ Example usage:
+ ```python
+ class SlotManager(tf.contrib.checkpoint.Checkpointable):
+
+ def __init__(self):
+ # Create a dependency named "slotdeps" on the container.
+ self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
+ slotdeps = self.slotdeps
+ slots = []
+ slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x"
+ slots.append(slotdeps.track(tfe.Variable(4.), "y"))
+ slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1"
+ ```
+ """
+
+ def __init__(self):
+ self._maybe_initialize_checkpointable()
+ self._name_counts = {}
+
+ def track(self, checkpointable, base_name):
+ """Add a dependency on `checkpointable`.
+
+ Args:
+ checkpointable: An object to add a checkpoint dependency on.
+ base_name: A name hint, which is uniquified to determine the dependency
+ name.
+ Returns:
+ `checkpointable`, for chaining.
+ Raises:
+ ValueError: If `checkpointable` is not a checkpointable object.
+ """
+
+ if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase):
+ raise ValueError(
+ ("Expected a checkpointable value, got %s which does not inherit "
+ "from CheckpointableBase.") % (checkpointable,))
+
+ def _format_name(prefix, number):
+ if number > 0:
+ return "%s_%d" % (prefix, number)
+ else:
+ return prefix
+
+ count = self._name_counts.get(base_name, 0)
+ candidate = _format_name(base_name, count)
+ while self._lookup_dependency(candidate) is not None:
+ count += 1
+ candidate = _format_name(base_name, count)
+ self._name_counts[base_name] = count + 1
+ return self._track_checkpointable(checkpointable, name=candidate)
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
new file mode 100644
index 0000000000..15775f4cb3
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -0,0 +1,100 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import six
+
+from tensorflow.contrib.checkpoint.python import containers
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
+from tensorflow.python.training.checkpointable_utils import object_metadata
+
+
+class UniqueNameTrackerTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testNames(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ x1 = resource_variable_ops.ResourceVariable(2.)
+ x2 = resource_variable_ops.ResourceVariable(3.)
+ x3 = resource_variable_ops.ResourceVariable(4.)
+ y = resource_variable_ops.ResourceVariable(5.)
+ slots = containers.UniqueNameTracker()
+ slots.track(x1, "x")
+ slots.track(x2, "x")
+ slots.track(x3, "x_1")
+ slots.track(y, "y")
+ self.evaluate((x1.initializer, x2.initializer, x3.initializer,
+ y.initializer))
+ save_root = checkpointable_utils.Checkpoint(slots=slots)
+ save_path = save_root.save(checkpoint_prefix)
+
+ restore_slots = checkpointable.Checkpointable()
+ restore_root = checkpointable_utils.Checkpoint(
+ slots=restore_slots)
+ status = restore_root.restore(save_path)
+ restore_slots.x = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.y = resource_variable_ops.ResourceVariable(0.)
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(2., self.evaluate(restore_slots.x))
+ self.assertEqual(3., self.evaluate(restore_slots.x_1))
+ self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
+ self.assertEqual(5., self.evaluate(restore_slots.y))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testExample(self):
+ class SlotManager(checkpointable.Checkpointable):
+
+ def __init__(self):
+ self.slotdeps = containers.UniqueNameTracker()
+ slotdeps = self.slotdeps
+ slots = []
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(3.), "x"))
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(4.), "y"))
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(5.), "x"))
+ self.slots = slots
+
+ manager = SlotManager()
+ self.evaluate([v.initializer for v in manager.slots])
+ checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = checkpoint.save(checkpoint_prefix)
+ metadata = object_metadata(save_path)
+ dependency_names = []
+ for node in metadata.nodes:
+ for child in node.children:
+ dependency_names.append(child.local_name)
+ six.assertCountEqual(
+ self,
+ dependency_names,
+ ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
+
+if __name__ == "__main__":
+ test.main()