aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/data_structures.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/data_structures.py')
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py472
1 files changed, 442 insertions, 30 deletions
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index c46585b417..507cda8734 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,49 +22,128 @@ import collections
import six
from tensorflow.python.ops import variables
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import layer_utils
-# TODO(allenl): We could track regular Python data structures which get assigned
-# to Checkpointable objects. Making this work with restore-on-create would be
-# tricky; we'd need to re-create nested structures with our own wrapped objects
-# on assignment to an attribute, and track the user's original structure to make
-# sure they don't modify it except through the wrappers (since we could save the
-# user's updated structure, but would have no way to support restore-on-create
-# for those modifications).
-# TODO(allenl): A dictionary data structure would be good too.
-class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase):
+class NoDependency(object):
+ """Allows attribute assignment to `Checkpointable` objects with no dependency.
+
+ Example usage:
+ ```python
+ obj = Checkpointable()
+ obj.has_dependency = tf.Variable(0., name="dep")
+ obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
+ assert obj.no_dependency.name == "nodep:0"
+ ```
+
+ `obj` in this example has a dependency on the variable "dep", and both
+ attributes contain un-wrapped `Variable` objects.
+
+ `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
+ dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
+ `Layer` to the attribute without a checkpoint dependency, but the `Model` will
+ still track the `Layer` (so it will appear in `Model.layers`, and its
+ variables will appear in `Model.variables`).
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _wrap_or_unwrap(value):
+ """Wraps basic data structures, unwraps NoDependency objects."""
+ if isinstance(value, NoDependency):
+ return value.value
+ if isinstance(value, base.CheckpointableBase):
+ return value # Skip conversion for already checkpointable objects.
+ elif isinstance(value, dict):
+ return _DictWrapper(value)
+ elif isinstance(value, list):
+ return _ListWrapper(value)
+ else:
+ return value
+ # TODO(allenl): Handle other common data structures. Tuples will require
+ # special casing (tuple subclasses are not weak referenceable, so replacement
+ # with a wrapper that subclasses tuple on attribute assignment works poorly,
+ # and replacement with a wrapper that isn't a tuple is also problematic),
+ # probably a tree traversal where the leaves are non-tuples(/namedtuples) to
+ # come up with names. Dictionaries should look like lists.
+
+
+def sticky_attribute_assignment(checkpointable, name, value):
+ """Adds dependencies, generally called from __setattr__.
+
+ This behavior is shared between Checkpointable and Model.
+
+ Respects NoDependency indicators, but otherwise makes checkpointable objects
+ out of common data structures and tracks objects by their attribute names.
+
+ Args:
+ checkpointable: The object to add dependencies to (generally the one having
+ an attribute assigned).
+ name: The attribute name being assigned.
+ value: The value being assigned. Not necessarily a checkpointable object.
+
+ Returns:
+ The value which should be stored in the attribute (unwrapped from a
+ NoDependency object if necessary).
+ """
+ if isinstance(value, NoDependency):
+ add_dependency = False
+ else:
+ add_dependency = True
+ value = _wrap_or_unwrap(value)
+ if not add_dependency:
+ return value
+ if isinstance(value, base.CheckpointableBase):
+ checkpointable._track_checkpointable( # pylint: disable=protected-access
+ value, name=name,
+ # Allow the user to switch the Checkpointable which is tracked by this
+ # name, since assigning a new variable to an attribute has
+ # historically been fine (e.g. Adam did this).
+ overwrite=True)
+ return value
+
+
+class CheckpointableDataStructure(base.CheckpointableBase):
"""Base class for data structures which contain checkpointable objects."""
def __init__(self):
+ # An append-only ordered set
self._layers = []
+
self.trainable = True
self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
- if isinstance(value, checkpointable_lib.CheckpointableBase):
- self._track_checkpointable(value, name=name)
- if isinstance(value, variables.Variable):
- self._extra_variables.append(value)
- else:
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
+ if not isinstance(value, base.CheckpointableBase):
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
if (isinstance(value, CheckpointableDataStructure)
or layer_utils.is_layer(value)):
- if value not in self._layers:
+ # Check for object-identity rather than with __eq__ to avoid
+ # de-duplicating empty container types. Automatically generated list
+ # wrappers keep things like "[] == []" true, which means "[] in [[]]" is
+ # also true. This becomes not true once one of the lists is mutated.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, "_use_resource_variables"):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True # pylint: disable=protected-access
+ return value
@property
def layers(self):
- return self._layers
+ return layer_utils.filter_empty_layer_containers(self._layers)
@property
def trainable_weights(self):
@@ -164,24 +243,28 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __init__(self, *args, **kwargs):
"""Construct a new sequence. Arguments are passed to `list()`."""
super(List, self).__init__()
- self._storage = list(*args, **kwargs)
+ self._storage = self._make_storage(*args, **kwargs)
for index, element in enumerate(self._storage):
- self._track_value(element, name=self._name_element(index))
+ self._storage[index] = self._track_value(
+ element, name=self._name_element(index))
+
+ def _make_storage(self, *args, **kwargs):
+ """Determines the backing storage (overridden in subclasses)."""
+ return list(*args, **kwargs)
def _name_element(self, index):
return "%d" % (index,)
def append(self, value):
"""Add a new checkpointable value."""
- self._track_value(value, self._name_element(len(self._storage)))
+ value = self._track_value(value, self._name_element(len(self._storage)))
self._storage.append(value)
def extend(self, values):
"""Add a sequence of checkpointable values."""
- for index_offset, value in enumerate(values):
- self._track_value(
- value, name=self._name_element(len(self._storage) + index_offset))
- self._storage.extend(values)
+ for value in values:
+ self._storage.append(self._track_value(
+ value, name=self._name_element(len(self._storage))))
def __iadd__(self, values):
self.extend(values)
@@ -189,9 +272,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __add__(self, other):
if isinstance(other, List):
- return List(self._storage + other._storage) # pylint: disable=protected-access
+ return self.__class__(self._storage + other._storage) # pylint: disable=protected-access
else:
- return List(self._storage + other)
+ return self.__class__(self._storage + other)
+
+ def __radd__(self, other):
+ return self + other
def __getitem__(self, key):
return self._storage[key]
@@ -203,6 +289,144 @@ class List(CheckpointableDataStructure, collections.Sequence):
return "List(%s)" % (repr(self._storage),)
+class _ListWrapper(List, collections.MutableSequence,
+ # Shadowed, but there for isinstance checks.
+ list):
+ """Wraps the built-in `list` to support restore-on-create for variables.
+
+ Unlike `List`, this sequence type is mutable in the same ways built-in lists
+ are. Instead of throwing an error immediately like `List`, it records
+ problematic mutations (e.g. assigning a new element to a position already
+ occupied, meaning both elements get the same names at different times) and
+ refuses to save.
+
+ On assignment to an attribute of a Model or Checkpointable object, Python
+ lists are replaced with _ListWrapper. Wrapping a list in a
+ `tf.contrib.checkpoint.NoDependency` object prevents this.
+ """
+
+ def __init__(self, wrapped_list):
+ """Construct a new list wrapper.
+
+ Args:
+ wrapped_list: The initial value of the data structure. A shallow copy may
+ be maintained for error checking. `wrapped_list` itself should not be
+ modified directly after constructing the `_ListWrapper`, and if changes
+ are detected the `_ListWrapper` will throw an exception on save.
+ """
+ # Monotonic flags which indicate this object would not be restored properly,
+ # and therefore should throw an error on save to avoid giving the impression
+ # that restoring it will work.
+ self._non_append_mutation = False
+ self._external_modification = False
+ super(_ListWrapper, self).__init__(wrapped_list)
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ def _make_storage(self, wrapped_list):
+ """Use the user's original list for storage."""
+ return wrapped_list
+
+ def _check_external_modification(self):
+ """Checks for any changes to the wrapped list not through the wrapper."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ if self._storage != self._last_wrapped_list_snapshot:
+ self._external_modification = True
+ self._last_wrapped_list_snapshot = None
+
+ def _update_snapshot(self):
+ """Acknowledges tracked changes to the wrapped list."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ @property
+ def _checkpoint_dependencies(self):
+ self._check_external_modification()
+ if self._non_append_mutation:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). A list element was replaced "
+ "(__setitem__), deleted, or inserted. In order to support "
+ "restoration on object creation, tracking is exclusively for "
+ "append-only data structures.\n\nIf you don't need this list "
+ "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+ "object; it will be automatically un-wrapped and subsequently "
+ "ignored." % (self,)))
+ if self._external_modification:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). The wrapped list was modified "
+ "outside the wrapper (its final value was %s, its value when a "
+ "checkpoint dependency was added was %s), which breaks restoration "
+ "on object creation.\n\nIf you don't need this list checkpointed, "
+ "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
+ "automatically un-wrapped and subsequently ignored." % (
+ self, self._storage, self._last_wrapped_list_snapshot)))
+ return super(_ListWrapper, self)._checkpoint_dependencies
+
+ def __delitem__(self, key):
+ self._non_append_mutation = True
+ del self._storage[key]
+
+ def __setitem__(self, key, value):
+ self._non_append_mutation = True
+ self._storage[key] = value
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._check_external_modification()
+ super(_ListWrapper, self).append(value)
+ self._update_snapshot()
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ self._check_external_modification()
+ super(_ListWrapper, self).extend(values)
+ self._update_snapshot()
+
+ def __eq__(self, other):
+ return self._storage == getattr(other, "_storage", other)
+
+ def __ne__(self, other):
+ return self._storage != getattr(other, "_storage", other)
+
+ def __lt__(self, other):
+ return self._storage < getattr(other, "_storage", other)
+
+ def __le__(self, other):
+ return self._storage <= getattr(other, "_storage", other)
+
+ def __gt__(self, other):
+ return self._storage > getattr(other, "_storage", other)
+
+ def __ge__(self, other):
+ return self._storage >= getattr(other, "_storage", other)
+
+ def __hash__(self):
+ # List wrappers need to compare like regular lists, and so like regular
+ # lists they don't belong in hash tables.
+ raise TypeError("unhashable type: 'ListWrapper'")
+
+ def insert(self, index, obj):
+ self._non_append_mutation = True
+ self._storage.insert(index, obj)
+
+ def _track_value(self, value, name):
+ """Allows storage of non-checkpointable objects."""
+ try:
+ value = super(_ListWrapper, self)._track_value(value=value, name=name)
+ except ValueError:
+ # Even if this value isn't checkpointable, we need to make sure
+ # NoDependency objects get unwrapped.
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ return value
+
+ def __repr__(self):
+ return "ListWrapper(%s)" % (repr(self._storage),)
+
+
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
@@ -216,9 +440,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
def __init__(self, *args, **kwargs):
"""Construct a new sequence. Arguments are passed to `dict()`."""
super(Mapping, self).__init__()
- self._storage = dict(*args, **kwargs)
- for key, value in self._storage.items():
- self._track_value(value, name=self._name_element(key))
+ self._storage = self._make_storage(*args, **kwargs)
+ self._storage.update(
+ {key: self._track_value(
+ value, name=self._name_element(key))
+ for key, value in self._storage.items()})
+
+ def _make_storage(self, *args, **kwargs):
+ return dict(*args, **kwargs)
def _name_element(self, key):
if not isinstance(key, six.string_types):
@@ -228,13 +457,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
return str(key)
def __setitem__(self, key, value):
+ name = self._name_element(key)
+ value = self._track_value(value, name=name)
current_value = self._storage.setdefault(key, value)
if current_value is not value:
raise ValueError(
("Mappings are an append-only data structure. Tried to overwrite the "
"key '%s' with value %s, but it already contains %s")
% (key, value, current_value))
- self._track_value(value, name=self._name_element(key))
def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
@@ -251,3 +481,185 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
def __iter__(self):
return iter(self._storage)
+
+
+# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance
+# checks seems infeasible. CPython will not call Python methods/properties on
+# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead
+# collects elements directly from dict_subclass's C structs. So subclassing dict
+# implies that the storage has to be "self" (i.e. the C structs for the object
+# must be updated correctly), but we also need that storage to be the wrapped
+# dictionary to avoid synchronization bugs (un-tracked external modifications
+# should still show up when the dict is accessed through the wrapper). Monkey
+# patching all of the "wrapped" dict's methods instead of creating a wrapper
+# object is an option, but not a very attractive one (replacing methods without
+# creating reference cycles is difficult, and then dicts would need to be
+# special cased everywhere as being checkpointable).
+class _DictWrapper(Mapping, collections.MutableMapping):
+ """Wraps built-in dicts to support restore-on-create for variables.
+
+ _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
+ _DictWrapper allows non-string keys and values and arbitrary mutations (delete
+ keys, reassign values). Like _ListWrapper, these mutations mean that
+ _DictWrapper will raise an exception on save.
+ """
+
+ def __new__(cls, *args):
+ if len(args) == 1 and isinstance(args[0], dict):
+ return super(_DictWrapper, cls).__new__(cls)
+ else:
+ # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
+ # this case there's nothing to wrap, so we make a normal dictionary. Also
+ # allows constructing empty instances of the _DictWrapper type, as Session
+ # is wont to do (and again there's nothing to wrap, so a normal dictionary
+ # makes more sense).
+ return dict(*args)
+
+ def __init__(self, wrapped_dict):
+ self._non_string_key = False
+ self._non_append_mutation = False
+ self._external_modification = False
+ super(_DictWrapper, self).__init__(wrapped_dict)
+ self._update_snapshot()
+
+ def _make_storage(self, wrapped_dict):
+ """Re-use the wrapped dict for storage (to force them to be in sync)."""
+ return wrapped_dict
+
+ @property
+ def _checkpoint_dependencies(self):
+ """Check that the object is saveable before listing its dependencies."""
+ self._check_external_modification()
+ if self._non_string_key:
+ raise ValueError(
+ "Unable to save the object %s (a dictionary wrapper constructed "
+ "automatically on attribute assignment). The wrapped dictionary "
+ "contains a non-string key which maps to a checkpointable object or "
+ "mutable data structure.\n\nIf you don't need this dictionary "
+ "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+ "object; it will be automatically un-wrapped and subsequently "
+ "ignored." % (self,))
+ if self._non_append_mutation:
+ raise ValueError(
+ "Unable to save the object %s (a dictionary wrapper constructed "
+ "automatically on attribute assignment). A key mapping to a "
+ "checkpointable object was overwritten or deleted, which would "
+ "cause problems for restoration.\n\nIf you don't need this "
+ "dictionary checkpointed, wrap it in a "
+ "tf.contrib.checkpoint.NoDependency object; it will be automatically "
+ "un-wrapped and subsequently ignored." % (self,))
+ if self._external_modification:
+ raise ValueError(
+ "Unable to save the object %s (a dictionary wrapper constructed "
+ "automatically on attribute assignment). The wrapped dictionary was "
+ "modified outside the wrapper (its final value was %s, its value "
+ "when a checkpoint dependency was added was %s), which breaks "
+ "restoration on object creation.\n\nIf you don't need this "
+ "dictionary checkpointed, wrap it in a "
+ "tf.contrib.checkpoint.NoDependency object; it will be automatically "
+ "un-wrapped and subsequently ignored." % (
+ self, self, self._last_wrapped_dict_snapshot))
+ assert not self._dirty # Any reason for dirtiness should have an exception.
+ return super(_DictWrapper, self)._checkpoint_dependencies
+
+ @property
+ def _dirty(self):
+ """Check if there has already been a mutation which prevents saving."""
+ return (self._external_modification
+ or self._non_append_mutation
+ or self._non_string_key)
+
+ def _check_external_modification(self):
+ """Checks for any changes to the wrapped dict not through the wrapper."""
+ if self._dirty:
+ return
+ if self != self._last_wrapped_dict_snapshot:
+ self._external_modification = True
+ self._last_wrapped_dict_snapshot = None
+
+ def _update_snapshot(self):
+ """Acknowledges tracked changes to the wrapped dict."""
+ if self._dirty:
+ return
+ self._last_wrapped_dict_snapshot = dict(self)
+
+ def _track_value(self, value, name):
+ """Allows storage of non-checkpointable objects."""
+ if isinstance(name, six.string_types):
+ string_key = True
+ else:
+ name = "-non_string_key"
+ string_key = False
+ try:
+ no_dependency = isinstance(value, NoDependency)
+ value = super(_DictWrapper, self)._track_value(value=value, name=name)
+ if not (string_key or no_dependency):
+ # A non-string key maps to a checkpointable value. This data structure
+ # is not saveable.
+ self._non_string_key = True
+ return value
+ except ValueError:
+ # Even if this value isn't checkpointable, we need to make sure
+ # NoDependency objects get unwrapped.
+ return sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+
+ def _name_element(self, key):
+ """Don't throw errors for non-string keys."""
+ if isinstance(key, six.string_types):
+ return super(_DictWrapper, self)._name_element(key)
+ else:
+ return key
+
+ def __setitem__(self, key, value):
+ """Allow any modifications, but possibly mark the wrapper as unsaveable."""
+ self._check_external_modification()
+ no_dep = isinstance(value, NoDependency)
+ if isinstance(key, six.string_types):
+ existing_dependency = self._lookup_dependency(key)
+ value = self._track_value(value, name=key)
+ else:
+ value = _wrap_or_unwrap(value)
+ existing_dependency = None
+ if not no_dep and isinstance(value, base.CheckpointableBase):
+ # Non-string keys are OK as long as we have no reason to add a
+ # dependency on the value (either because the value is not
+ # checkpointable, or because it was wrapped in a NoDependency object).
+ self._non_string_key = True
+ current_value = self._storage.setdefault(key, value)
+ if current_value is not value:
+ if ((not no_dep and isinstance(value, base.CheckpointableBase))
+ # We don't want to just check that the existing object is
+ # checkpointable, since it may have been wrapped in a NoDependency
+ # object.
+ or existing_dependency is not None):
+ # A checkpointable object was replaced under the same key; this means
+ # that restoring would be error-prone, so we'll throw an exception on
+ # save.
+ self._non_append_mutation = True
+ self._storage[key] = value
+
+ self._update_snapshot()
+
+ def __delitem__(self, key):
+ self._check_external_modification()
+ existing_value = self[key]
+ if isinstance(existing_value, base.CheckpointableBase):
+ # Deleting tracked checkpointable values means restoring is problematic,
+ # so we'll throw an exception on save.
+ self._non_append_mutation = True
+ del self._storage[key]
+ self._update_snapshot()
+
+ def __repr__(self):
+ return "DictWrapper(%s)" % (repr(self._storage),)
+
+ def __hash__(self):
+ raise TypeError("unhashable type: 'DictWrapper'")
+
+ def __eq__(self, other):
+ return self._storage == getattr(other, "_storage", other)
+
+ def update(self, *args, **kwargs):
+ for key, value in dict(*args, **kwargs).items():
+ self[key] = value