diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures.py | 189 |
1 files changed, 188 insertions, 1 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 019d43f09c..507cda8734 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -57,6 +57,8 @@ def _wrap_or_unwrap(value): return value.value if isinstance(value, base.CheckpointableBase): return value # Skip conversion for already checkpointable objects. + elif isinstance(value, dict): + return _DictWrapper(value) elif isinstance(value, list): return _ListWrapper(value) else: @@ -438,12 +440,15 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): def __init__(self, *args, **kwargs): """Construct a new sequence. Arguments are passed to `dict()`.""" super(Mapping, self).__init__() - self._storage = dict(*args, **kwargs) + self._storage = self._make_storage(*args, **kwargs) self._storage.update( {key: self._track_value( value, name=self._name_element(key)) for key, value in self._storage.items()}) + def _make_storage(self, *args, **kwargs): + return dict(*args, **kwargs) + def _name_element(self, key): if not isinstance(key, six.string_types): raise TypeError( @@ -476,3 +481,185 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): def __iter__(self): return iter(self._storage) + + +# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance +# checks seems infeasible. CPython will not call Python methods/properties on +# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead +# collects elements directly from dict_subclass's C structs. So subclassing dict +# implies that the storage has to be "self" (i.e. the C structs for the object +# must be updated correctly), but we also need that storage to be the wrapped +# dictionary to avoid synchronization bugs (un-tracked external modifications +# should still show up when the dict is accessed through the wrapper). 
# Monkey patching all of the "wrapped" dict's methods instead of creating a
# wrapper object is an option, but not a very attractive one (replacing methods
# without creating reference cycles is difficult, and then dicts would need to
# be special cased everywhere as being checkpointable).
class _DictWrapper(Mapping, collections.MutableMapping):
  """Wraps built-in dicts to support restore-on-create for variables.

  _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
  keys, reassign values). Like _ListWrapper, these mutations mean that
  _DictWrapper will raise an exception on save.

  The wrapper keeps three "dirty" flags (`_non_string_key`,
  `_non_append_mutation`, `_external_modification`). Mutations are always
  allowed; the flags only cause `_checkpoint_dependencies` to raise a
  `ValueError` when the object is later asked for its dependencies.
  """

  def __new__(cls, *args):
    """Creates a wrapper for a single dict argument, else a plain `dict`.

    Args:
      *args: Either a single `dict` to wrap, or arguments accepted by `dict()`.

    Returns:
      A new `_DictWrapper` when called with exactly one `dict`; otherwise the
      result of `dict(*args)`.
    """
    if len(args) == 1 and isinstance(args[0], dict):
      return super(_DictWrapper, cls).__new__(cls)
    else:
      # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
      # this case there's nothing to wrap, so we make a normal dictionary. Also
      # allows constructing empty instances of the _DictWrapper type, as Session
      # is wont to do (and again there's nothing to wrap, so a normal dictionary
      # makes more sense).
      return dict(*args)

  def __init__(self, wrapped_dict):
    """Wraps `wrapped_dict`, tracking its current values.

    Args:
      wrapped_dict: The `dict` to wrap. It is re-used as the wrapper's storage
        rather than copied (see `_make_storage`).
    """
    # The flags must exist before the parent constructor runs, since tracking
    # the initial values goes through `_track_value`, which may set
    # `_non_string_key`.
    self._non_string_key = False
    self._non_append_mutation = False
    self._external_modification = False
    super(_DictWrapper, self).__init__(wrapped_dict)
    self._update_snapshot()

  def _make_storage(self, wrapped_dict):
    """Re-use the wrapped dict for storage (to force them to be in sync)."""
    return wrapped_dict

  @property
  def _checkpoint_dependencies(self):
    """Check that the object is saveable before listing its dependencies.

    Returns:
      The dependencies reported by the parent class.

    Raises:
      ValueError: If a non-string key maps to a checkpointable value, if a
        tracked value was overwritten or deleted, or if the wrapped dict was
        modified without going through the wrapper.
    """
    self._check_external_modification()
    if self._non_string_key:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
          "contains a non-string key which maps to a checkpointable object or "
          "mutable data structure.\n\nIf you don't need this dictionary "
          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
          "object; it will be automatically un-wrapped and subsequently "
          "ignored." % (self,))
    if self._non_append_mutation:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). A key mapping to a "
          "checkpointable object was overwritten or deleted, which would "
          "cause problems for restoration.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (self,))
    if self._external_modification:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary was "
          "modified outside the wrapper (its final value was %s, its value "
          "when a checkpoint dependency was added was %s), which breaks "
          "restoration on object creation.\n\nIf you don't need this "
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (
              self, self, self._last_wrapped_dict_snapshot))
    assert not self._dirty  # Any reason for dirtiness should have an exception.
    return super(_DictWrapper, self)._checkpoint_dependencies

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving."""
    return (self._external_modification
            or self._non_append_mutation
            or self._non_string_key)

  def _check_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper.

    Compares the wrapper against the snapshot taken after the last tracked
    change and sets `_external_modification` if they differ. Does nothing once
    the wrapper is already dirty.
    """
    if self._dirty:
      return
    if self != self._last_wrapped_dict_snapshot:
      # Keep the snapshot rather than clearing it: the error raised from
      # `_checkpoint_dependencies` reports it as the last tracked value, and
      # `_update_snapshot` will not overwrite it now that the wrapper is dirty.
      self._external_modification = True

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict."""
    if self._dirty:
      return
    self._last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-checkpointable objects.

    Args:
      value: The value being stored in the dictionary.
      name: The dictionary key. Non-string keys are tracked under the
        placeholder name "-non_string_key".

    Returns:
      The value to store, as returned by the parent's `_track_value` or, if
      that raises `ValueError`, by `sticky_attribute_assignment`.
    """
    if isinstance(name, six.string_types):
      string_key = True
    else:
      name = "-non_string_key"
      string_key = False
    try:
      no_dependency = isinstance(value, NoDependency)
      value = super(_DictWrapper, self)._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a checkpointable value. This data structure
        # is not saveable.
        self._non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't checkpointable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          checkpointable=self, value=value, name=name)

  def _name_element(self, key):
    """Don't throw errors for non-string keys."""
    if isinstance(key, six.string_types):
      return super(_DictWrapper, self)._name_element(key)
    else:
      return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
    self._check_external_modification()
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, six.string_types):
      # Look up the old dependency before tracking the new value under `key`.
      existing_dependency = self._lookup_dependency(key)
      value = self._track_value(value, name=key)
    else:
      value = _wrap_or_unwrap(value)
      existing_dependency = None
      if not no_dep and isinstance(value, base.CheckpointableBase):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # checkpointable, or because it was wrapped in a NoDependency object).
        self._non_string_key = True
    # `setdefault` inserts `value` for a new key; for an existing key it
    # returns the stored value so a replacement can be detected by identity.
    current_value = self._storage.setdefault(key, value)
    if current_value is not value:
      if ((not no_dep and isinstance(value, base.CheckpointableBase))
          # We don't want to just check that the existing object is
          # checkpointable, since it may have been wrapped in a NoDependency
          # object.
          or existing_dependency is not None):
        # A checkpointable object was replaced under the same key; this means
        # that restoring would be error-prone, so we'll throw an exception on
        # save.
        self._non_append_mutation = True
      self._storage[key] = value

    self._update_snapshot()

  def __delitem__(self, key):
    """Deletes `key`, marking the wrapper unsaveable if its value was tracked."""
    self._check_external_modification()
    existing_value = self[key]
    if isinstance(existing_value, base.CheckpointableBase):
      # Deleting tracked checkpointable values means restoring is problematic,
      # so we'll throw an exception on save.
      self._non_append_mutation = True
    del self._storage[key]
    self._update_snapshot()

  def __repr__(self):
    """Returns a string showing the wrapped dictionary's contents."""
    return "DictWrapper(%s)" % (repr(self._storage),)

  def __hash__(self):
    """Raises TypeError; the wrapper is unhashable."""
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    """Compares wrapped storage with `other` (or `other._storage` if present)."""
    return self._storage == getattr(other, "_storage", other)

  def update(self, *args, **kwargs):
    """Sets each item via `__setitem__`. Arguments are passed to `dict()`."""
    for key, value in dict(*args, **kwargs).items():
      self[key] = value