about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/training/checkpointable/data_structures.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures.py')
-rw-r--r-- tensorflow/python/training/checkpointable/data_structures.py 189
1 files changed, 188 insertions, 1 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 019d43f09c..507cda8734 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -57,6 +57,8 @@ def _wrap_or_unwrap(value):
return value.value
if isinstance(value, base.CheckpointableBase):
return value # Skip conversion for already checkpointable objects.
+ elif isinstance(value, dict):
+ return _DictWrapper(value)
elif isinstance(value, list):
return _ListWrapper(value)
else:
@@ -438,12 +440,15 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
def __init__(self, *args, **kwargs):
"""Construct a new mapping. Arguments are passed to `dict()`."""
super(Mapping, self).__init__()
- self._storage = dict(*args, **kwargs)
+ self._storage = self._make_storage(*args, **kwargs)
# Re-store every initial value as whatever `_track_value` returns for it,
# using the element name derived from its key.
self._storage.update(
{key: self._track_value(
value, name=self._name_element(key))
for key, value in self._storage.items()})
+ def _make_storage(self, *args, **kwargs):
+ """Builds the object used as `self._storage`; subclasses may override."""
+ return dict(*args, **kwargs)
+
def _name_element(self, key):
if not isinstance(key, six.string_types):
raise TypeError(
@@ -476,3 +481,185 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
def __iter__(self):
"""Returns an iterator over `self._storage`."""
return iter(self._storage)
+
+
+# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance
+# checks seems infeasible. CPython will not call Python methods/properties on
+# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead
+# collects elements directly from dict_subclass's C structs. So subclassing dict
+# implies that the storage has to be "self" (i.e. the C structs for the object
+# must be updated correctly), but we also need that storage to be the wrapped
+# dictionary to avoid synchronization bugs (un-tracked external modifications
+# should still show up when the dict is accessed through the wrapper). Monkey
+# patching all of the "wrapped" dict's methods instead of creating a wrapper
+# object is an option, but not a very attractive one (replacing methods without
+# creating reference cycles is difficult, and then dicts would need to be
+# special cased everywhere as being checkpointable).
+class _DictWrapper(Mapping, collections.MutableMapping):
+  """Wraps built-in dicts to support restore-on-create for variables.
+
+  _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
+  _DictWrapper allows non-string keys and values and arbitrary mutations (delete
+  keys, reassign values). Like _ListWrapper, these mutations mean that
+  _DictWrapper will raise an exception on save.
+
+  The wrapped dict itself is used as storage (it is not copied), so changes
+  made directly to it are visible through the wrapper. Such changes are
+  detected by comparing against a snapshot taken after every tracked change.
+  """
+
+  def __new__(cls, *args):
+    """Creates a wrapper for a single dict argument, else a plain dict.
+
+    Args:
+      *args: Constructor arguments. A wrapper is only created when this is
+        exactly one `dict` instance.
+
+    Returns:
+      A new `_DictWrapper` when called with exactly one `dict`, otherwise the
+      result of `dict(*args)` (in which case `__init__` is not run, since the
+      returned object is not an instance of `cls`).
+    """
+    if len(args) == 1 and isinstance(args[0], dict):
+      return super(_DictWrapper, cls).__new__(cls)
+    else:
+      # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
+      # this case there's nothing to wrap, so we make a normal dictionary. Also
+      # allows constructing empty instances of the _DictWrapper type, as Session
+      # is wont to do (and again there's nothing to wrap, so a normal dictionary
+      # makes more sense).
+      return dict(*args)
+
+  def __init__(self, wrapped_dict):
+    """Wraps `wrapped_dict`, tracking its current values.
+
+    Args:
+      wrapped_dict: The `dict` to wrap. It is re-used as the wrapper's storage.
+    """
+    # The flags must exist before Mapping.__init__ runs, because it calls
+    # `_track_value`, which may set `_non_string_key`.
+    self._non_string_key = False
+    self._non_append_mutation = False
+    self._external_modification = False
+    super(_DictWrapper, self).__init__(wrapped_dict)
+    self._update_snapshot()
+
+  def _make_storage(self, wrapped_dict):
+    """Re-use the wrapped dict for storage (to force them to be in sync)."""
+    return wrapped_dict
+
+  @property
+  def _checkpoint_dependencies(self):
+    """Check that the object is saveable before listing its dependencies.
+
+    Returns:
+      The `_checkpoint_dependencies` of the parent class.
+
+    Raises:
+      ValueError: If a non-string key maps to a checkpointable value, if a
+        tracked value was overwritten or deleted, or if the wrapped dict was
+        modified outside of the wrapper.
+    """
+    self._check_external_modification()
+    if self._non_string_key:
+      raise ValueError(
+          "Unable to save the object %s (a dictionary wrapper constructed "
+          "automatically on attribute assignment). The wrapped dictionary "
+          "contains a non-string key which maps to a checkpointable object or "
+          "mutable data structure.\n\nIf you don't need this dictionary "
+          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+          "object; it will be automatically un-wrapped and subsequently "
+          "ignored." % (self,))
+    if self._non_append_mutation:
+      raise ValueError(
+          "Unable to save the object %s (a dictionary wrapper constructed "
+          "automatically on attribute assignment). A key mapping to a "
+          "checkpointable object was overwritten or deleted, which would "
+          "cause problems for restoration.\n\nIf you don't need this "
+          "dictionary checkpointed, wrap it in a "
+          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
+          "un-wrapped and subsequently ignored." % (self,))
+    if self._external_modification:
+      # `_last_wrapped_dict_snapshot` still holds the contents recorded after
+      # the last tracked change, so the message shows both states.
+      raise ValueError(
+          "Unable to save the object %s (a dictionary wrapper constructed "
+          "automatically on attribute assignment). The wrapped dictionary was "
+          "modified outside the wrapper (its final value was %s, its value "
+          "when a checkpoint dependency was added was %s), which breaks "
+          "restoration on object creation.\n\nIf you don't need this "
+          "dictionary checkpointed, wrap it in a "
+          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
+          "un-wrapped and subsequently ignored." % (
+              self, self, self._last_wrapped_dict_snapshot))
+    assert not self._dirty  # Any reason for dirtiness should have an exception.
+    return super(_DictWrapper, self)._checkpoint_dependencies
+
+  @property
+  def _dirty(self):
+    """Check if there has already been a mutation which prevents saving."""
+    return (self._external_modification
+            or self._non_append_mutation
+            or self._non_string_key)
+
+  def _check_external_modification(self):
+    """Checks for any changes to the wrapped dict not through the wrapper.
+
+    Does nothing if the wrapper is already dirty. Otherwise, if the current
+    contents differ from the last snapshot, marks the wrapper as externally
+    modified. The snapshot is deliberately kept rather than cleared: it is
+    reported in the error raised by `_checkpoint_dependencies`, and
+    `_update_snapshot` never overwrites it once the wrapper is dirty.
+    """
+    if self._dirty:
+      return
+    if self != self._last_wrapped_dict_snapshot:
+      self._external_modification = True
+
+  def _update_snapshot(self):
+    """Acknowledges tracked changes to the wrapped dict."""
+    if self._dirty:
+      # Once dirty the wrapper can never be saved, so keep the last snapshot.
+      return
+    self._last_wrapped_dict_snapshot = dict(self)
+
+  def _track_value(self, value, name):
+    """Allows storage of non-checkpointable objects.
+
+    Args:
+      value: The value being stored.
+      name: The key the value is stored under. Non-string keys are tracked
+        under the placeholder name "-non_string_key".
+
+    Returns:
+      The value to store: the parent's `_track_value` result, or, if that
+      raises `ValueError`, the result of `sticky_attribute_assignment`.
+    """
+    if isinstance(name, six.string_types):
+      string_key = True
+    else:
+      name = "-non_string_key"
+      string_key = False
+    try:
+      # Checked before tracking, since `value` is rebound to the tracked value.
+      no_dependency = isinstance(value, NoDependency)
+      value = super(_DictWrapper, self)._track_value(value=value, name=name)
+      if not (string_key or no_dependency):
+        # A non-string key maps to a checkpointable value. This data structure
+        # is not saveable.
+        self._non_string_key = True
+      return value
+    except ValueError:
+      # Even if this value isn't checkpointable, we need to make sure
+      # NoDependency objects get unwrapped.
+      return sticky_attribute_assignment(
+          checkpointable=self, value=value, name=name)
+
+  def _name_element(self, key):
+    """Don't throw errors for non-string keys."""
+    if isinstance(key, six.string_types):
+      return super(_DictWrapper, self)._name_element(key)
+    else:
+      return key
+
+  def __setitem__(self, key, value):
+    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
+    self._check_external_modification()
+    no_dep = isinstance(value, NoDependency)
+    if isinstance(key, six.string_types):
+      # Look up the old dependency first; tracking the new value may replace it.
+      existing_dependency = self._lookup_dependency(key)
+      value = self._track_value(value, name=key)
+    else:
+      value = _wrap_or_unwrap(value)
+      existing_dependency = None
+      if not no_dep and isinstance(value, base.CheckpointableBase):
+        # Non-string keys are OK as long as we have no reason to add a
+        # dependency on the value (either because the value is not
+        # checkpointable, or because it was wrapped in a NoDependency object).
+        self._non_string_key = True
+    # setdefault returns the already-stored value if `key` is present, which
+    # lets us tell a fresh insertion from an overwrite.
+    current_value = self._storage.setdefault(key, value)
+    if current_value is not value:
+      if ((not no_dep and isinstance(value, base.CheckpointableBase))
+          # We don't want to just check that the existing object is
+          # checkpointable, since it may have been wrapped in a NoDependency
+          # object.
+          or existing_dependency is not None):
+        # A checkpointable object was replaced under the same key; this means
+        # that restoring would be error-prone, so we'll throw an exception on
+        # save.
+        self._non_append_mutation = True
+      self._storage[key] = value
+
+    self._update_snapshot()
+
+  def __delitem__(self, key):
+    """Deletes `key`, marking the wrapper unsaveable if its value was tracked.
+
+    Raises:
+      KeyError: If `key` is not present (raised by the lookup).
+    """
+    self._check_external_modification()
+    existing_value = self[key]
+    # NOTE(review): this checks the stored value's type, whereas __setitem__
+    # consults `_lookup_dependency`; presumably a value that was stored via
+    # NoDependency also trips this check -- confirm whether that is intended.
+    if isinstance(existing_value, base.CheckpointableBase):
+      # Deleting tracked checkpointable values means restoring is problematic,
+      # so we'll throw an exception on save.
+      self._non_append_mutation = True
+    del self._storage[key]
+    self._update_snapshot()
+
+  def __repr__(self):
+    """Returns "DictWrapper(...)" around the repr of the wrapped dict."""
+    return "DictWrapper(%s)" % (repr(self._storage),)
+
+  def __hash__(self):
+    """Always raises TypeError; wrappers are unhashable."""
+    raise TypeError("unhashable type: 'DictWrapper'")
+
+  def __eq__(self, other):
+    """Compares storage with `other._storage` if present, else with `other`."""
+    return self._storage == getattr(other, "_storage", other)
+
+  def update(self, *args, **kwargs):
+    """Sets items from `dict(*args, **kwargs)` one by one via `__setitem__`."""
+    # Going through `self[key] = value` keeps every change tracked.
+    for key, value in dict(*args, **kwargs).items():
+      self[key] = value