diff options
author | Allen Lavoie <allenl@google.com> | 2018-06-29 14:02:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-29 14:04:49 -0700 |
commit | dcaa037571ab0933977f70574f4f78875155ae20 (patch) | |
tree | 4968e1966ca334f42296beae6cb1ecd8d483215e /tensorflow/contrib/checkpoint | |
parent | b3c163a754574faed4337f869c2d650a9f45c09c (diff) |
Auto tracking for Python lists assigned to attributes of Model/Checkpointable
Conceptually lists just get replaced with a list-like wrapper. A shallow copy is maintained for error checking (since appends to it aren't monitored, we can't do restore-on-create for variables unless it's being modified through the wrapper).
There are lots of other details. I gave up on generalizing our isinstance(obj, (list, tuple)) checks and just subclassed list. Behaving like a list means the type should be unhashable, which requires some workarounds when we're collecting objects (object-identity collections, and object-identity versions of weak reference containers).
Adds a decorator for exempting whole methods from automatic dependency tracking so we don't need to track down every last self.inputs = [] statement to avoid polluting dependencies.
There's a TODO for tuples and dictionaries.
PiperOrigin-RevId: 202703271
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r-- | tensorflow/contrib/checkpoint/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/checkpoint/python/containers_test.py | 3 |
2 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8c1ce5c2a2..2fbaa31d5e 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -44,8 +44,8 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.data_structures import NoDependency from tensorflow.python.training.checkpointable.tracking import Checkpointable -from tensorflow.python.training.checkpointable.tracking import NoDependency from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 64d056bd68..ac85c7be80 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,6 +26,7 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util @@ -79,7 +80,7 @@ class UniqueNameTrackerTests(test.TestCase): resource_variable_ops.ResourceVariable(4.), "y")) slots.append(slotdeps.track( resource_variable_ops.ResourceVariable(5.), "x")) - self.slots = slots + self.slots = data_structures.NoDependency(slots) manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) |