diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/util.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/util.py | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 6ae5765b13..5d26a817d4 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -361,24 +361,42 @@ class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary): yield unwrapped -class _ObjectIdentityWeakSet(collections.MutableSet): - """Like weakref.WeakSet, but compares objects with "is".""" +class _ObjectIdentitySet(collections.MutableSet): + """Like the built-in set, but compares objects with "is".""" - def __init__(self): - self._storage = set() + def __init__(self, *args): + self._storage = set([self._wrap_key(obj) for obj in list(*args)]) + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) def __contains__(self, key): - return _WeakObjectIdentityWrapper(key) in self._storage + return self._wrap_key(key) in self._storage def discard(self, key): - self._storage.discard(_WeakObjectIdentityWrapper(key)) + self._storage.discard(self._wrap_key(key)) def add(self, key): - self._storage.add(_WeakObjectIdentityWrapper(key)) + self._storage.add(self._wrap_key(key)) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + yield key.unwrapped + + +class _ObjectIdentityWeakSet(_ObjectIdentitySet): + """Like weakref.WeakSet, but compares objects with "is".""" + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) def __len__(self): # Iterate, discarding old weak refs - return len(list(self)) + return len([_ for _ in self]) def __iter__(self): keys = list(self._storage) @@ -747,7 +765,7 @@ def capture_dependencies(template): initial_value=initializer, name=name, **inner_kwargs) - if name.startswith(name_prefix): + if name is not None and name.startswith(name_prefix): scope_stripped_name = name[len(name_prefix) + 1:] if not checkpointable_parent: return template._add_variable_with_custom_getter( # pylint: disable=protected-access @@ -857,8 +875,8 @@ class CheckpointLoadStatus(_LoadStatus): for checkpointable_object in list_objects(self._root_checkpointable): self._checkpoint.all_python_objects.add(checkpointable_object) unused_python_objects = ( - set(self._checkpoint.all_python_objects) - - set(self._checkpoint.object_by_proto_id.values())) + _ObjectIdentitySet(self._checkpoint.all_python_objects) + - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values())) if unused_python_objects: raise AssertionError( ("Some Python objects were not bound to checkpointed values, likely " |