aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-24 11:23:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 11:26:06 -0700
commitcdc1b4756a41dbfa7e7f39c466ff65dd88407cc0 (patch)
tree764bdf76cbd149d9331b15c466d790489572b8d7 /tensorflow/contrib/checkpoint
parent61dd76952e1e9a312105b7497f34d32d1a00a04b (diff)
Make the existing checkpointable data structure a CheckpointableDataStructure
Gives it better/more consistent handling of Layers. PiperOrigin-RevId: 197923880
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/python/containers.py7
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py9
2 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 9807abae1f..4d3d531299 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -18,9 +18,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import data_structures
-class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
+class UniqueNameTracker(data_structures.CheckpointableDataStructure):
"""Adds dependencies on checkpointable objects with name hints.
Useful for creating dependencies with locally unique names.
@@ -41,6 +42,7 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
"""
def __init__(self):
+ super(UniqueNameTracker, self).__init__()
self._maybe_initialize_checkpointable()
self._name_counts = {}
@@ -74,4 +76,5 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
count += 1
candidate = _format_name(base_name, count)
self._name_counts[base_name] = count + 1
- return self._track_checkpointable(checkpointable, name=candidate)
+ self._track_value(checkpointable, name=candidate)
+ return checkpointable
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 851a800588..3717d7f583 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -22,6 +22,8 @@ import six
from tensorflow.contrib.checkpoint.python import containers
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -95,5 +97,12 @@ class UniqueNameTrackerTests(test.TestCase):
dependency_names,
["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
+ @test_util.run_in_graph_and_eager_modes()
+ def testLayers(self):
+ tracker = containers.UniqueNameTracker()
+ tracker.track(layers.Dense(3), "dense")
+ tracker.layers[0](array_ops.zeros([1, 1]))
+ self.assertEqual(2, len(tracker.trainable_weights))
+
if __name__ == "__main__":
test.main()