aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/util.py')
-rw-r--r--tensorflow/python/training/checkpointable/util.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 6ae5765b13..5d26a817d4 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -361,24 +361,42 @@ class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary):
yield unwrapped
-class _ObjectIdentityWeakSet(collections.MutableSet):
- """Like weakref.WeakSet, but compares objects with "is"."""
+class _ObjectIdentitySet(collections.MutableSet):
+ """Like the built-in set, but compares objects with "is"."""
- def __init__(self):
- self._storage = set()
+ def __init__(self, *args):
+ self._storage = set([self._wrap_key(obj) for obj in list(*args)])
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
def __contains__(self, key):
- return _WeakObjectIdentityWrapper(key) in self._storage
+ return self._wrap_key(key) in self._storage
def discard(self, key):
- self._storage.discard(_WeakObjectIdentityWrapper(key))
+ self._storage.discard(self._wrap_key(key))
def add(self, key):
- self._storage.add(_WeakObjectIdentityWrapper(key))
+ self._storage.add(self._wrap_key(key))
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ yield key.unwrapped
+
+
+class _ObjectIdentityWeakSet(_ObjectIdentitySet):
+ """Like weakref.WeakSet, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
def __len__(self):
# Iterate, discarding old weak refs
- return len(list(self))
+ return len([_ for _ in self])
def __iter__(self):
keys = list(self._storage)
@@ -747,7 +765,7 @@ def capture_dependencies(template):
initial_value=initializer,
name=name,
**inner_kwargs)
- if name.startswith(name_prefix):
+ if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
if not checkpointable_parent:
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
@@ -857,8 +875,8 @@ class CheckpointLoadStatus(_LoadStatus):
for checkpointable_object in list_objects(self._root_checkpointable):
self._checkpoint.all_python_objects.add(checkpointable_object)
unused_python_objects = (
- set(self._checkpoint.all_python_objects)
- - set(self._checkpoint.object_by_proto_id.values()))
+ _ObjectIdentitySet(self._checkpoint.all_python_objects)
+ - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "