diff options
author | 2018-05-24 11:23:18 -0700 | |
---|---|---|
committer | 2018-05-24 11:26:06 -0700 | |
commit | cdc1b4756a41dbfa7e7f39c466ff65dd88407cc0 (patch) | |
tree | 764bdf76cbd149d9331b15c466d790489572b8d7 /tensorflow/contrib/checkpoint | |
parent | 61dd76952e1e9a312105b7497f34d32d1a00a04b (diff) |
Make the existing checkpointable data structure a CheckpointableDataStructure
Gives it better/more consistent handling of Layers.
PiperOrigin-RevId: 197923880
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/python/containers.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/containers_test.py | 9 |
2 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 9807abae1f..4d3d531299 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -18,9 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures -class UniqueNameTracker(checkpointable_lib.CheckpointableBase): +class UniqueNameTracker(data_structures.CheckpointableDataStructure): """Adds dependencies on checkpointable objects with name hints. Useful for creating dependencies with locally unique names. @@ -41,6 +42,7 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): """ def __init__(self): + super(UniqueNameTracker, self).__init__() self._maybe_initialize_checkpointable() self._name_counts = {} @@ -74,4 +76,5 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - return self._track_checkpointable(checkpointable, name=candidate) + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 851a800588..3717d7f583 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -22,6 +22,8 @@ import six from tensorflow.contrib.checkpoint.python import containers from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test from tensorflow.python.training.checkpointable import base as checkpointable @@ -95,5 +97,12 @@ class UniqueNameTrackerTests(test.TestCase): dependency_names, ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + @test_util.run_in_graph_and_eager_modes() + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + if __name__ == "__main__": test.main() |