author     Allen Lavoie <allenl@google.com>                2018-06-29 14:02:26 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-06-29 14:04:49 -0700
commit     dcaa037571ab0933977f70574f4f78875155ae20 (patch)
tree       4968e1966ca334f42296beae6cb1ecd8d483215e
parent     b3c163a754574faed4337f869c2d650a9f45c09c (diff)
Auto tracking for Python lists assigned to attributes of Model/Checkpointable
Conceptually, a list assigned to an attribute just gets replaced with a list-like wrapper. A shallow copy of the original list is maintained for error checking: appends made directly to the original list aren't monitored, so restore-on-create for variables only works when the list is modified through the wrapper.
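As a minimal sketch of the resulting failure mode (adapted from testOutOfBandEditDirtiesList in the updated tracking_test.py; the checkpoint prefix is arbitrary):
```python
import os

from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util

a = tracking.Checkpointable()
held_reference = [tracking.Checkpointable()]
a.l = held_reference  # The list is transparently replaced with a wrapper.
held_reference.append(tracking.Checkpointable())  # Bypasses the wrapper.

checkpoint = util.Checkpoint(a=a)
# The wrapper's shallow snapshot no longer matches its storage, so saving
# raises ValueError("The wrapped list was modified outside the wrapper ...").
checkpoint.save(os.path.join("/tmp", "ckpt"))
```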
There are lots of other details. I gave up on generalizing our isinstance(obj, (list, tuple)) checks and just subclassed list. Behaving like a list means the type must be unhashable, which requires some workarounds when we're collecting objects (object-identity collections, and object-identity versions of weak-reference containers).
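Condensed from the object-identity helpers this change adds to util.py (the real version also has weak-reference, dictionary, and set variants):
```python
class _ObjectIdentityWrapper(object):
  """Maps __eq__ on the wrapper to "is" on the wrapped object."""

  def __init__(self, wrapped):
    self._wrapped = wrapped

  def __eq__(self, other):
    if isinstance(other, _ObjectIdentityWrapper):
      return self._wrapped is other._wrapped
    return self._wrapped is other

  def __hash__(self):
    # Safe because __eq__ is object identity: id() is stable for the
    # lifetime of the wrapped object. This lets unhashable payloads (like
    # the new list wrapper) key ordinary sets and dictionaries.
    return id(self._wrapped)
```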
Adds a decorator for exempting whole methods from automatic dependency tracking, so avoiding polluted dependencies doesn't require tracking down every last self.inputs = [] statement.
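Usage, taken from the new test in tracking_test.py (module paths are the internal ones used throughout this change):
```python
from tensorflow.python.keras.engine import training
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util


class NoDependencyModel(training.Model):

  @base.no_automatic_dependency_tracking
  def __init__(self):
    super(NoDependencyModel, self).__init__()
    self.a = []  # Not wrapped, and no dependency is added.
    self.b = tracking.Checkpointable()  # Also exempt from tracking.


nodeps = NoDependencyModel()
assert [nodeps] == util.list_objects(nodeps)  # Only the root is tracked.
```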
There's a TODO for tuples and dictionaries.
PiperOrigin-RevId: 202703271
20 files changed, 728 insertions, 137 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8c1ce5c2a2..2fbaa31d5e 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -44,8 +44,8 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.data_structures import NoDependency from tensorflow.python.training.checkpointable.tracking import Checkpointable -from tensorflow.python.training.checkpointable.tracking import NoDependency from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 64d056bd68..ac85c7be80 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,6 +26,7 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util @@ -79,7 +80,7 @@ class UniqueNameTrackerTests(test.TestCase): resource_variable_ops.ResourceVariable(4.), "y")) slots.append(slotdeps.track( resource_variable_ops.ResourceVariable(5.), "x")) - self.slots = slots + self.slots = data_structures.NoDependency(slots) manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index f3b788f931..e037925961 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -361,7 +361,7 @@ class _ListFetchMapper(_FetchMapper): for m, vi in zip(self._mappers, self._value_indices): results.append(m.build_results([values[j] for j in vi])) # Return a value of the original type of the fetches. 
- if self._fetch_type == list: + if issubclass(self._fetch_type, list): return results elif self._fetch_type == tuple: return tuple(results) diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 5769f5739c..cb37f99704 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -45,6 +45,8 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -241,8 +243,17 @@ def _in_place_subclassed_model_state_restoration(model): # Restore layers and build attributes if (hasattr(model, '_original_attributes_cache') and model._original_attributes_cache is not None): - model._layers = [] + # Models have sticky attribute assignment, so we want to be careful to add + # back the previous attributes and track Layers by their original names + # without adding dependencies on "utility" attributes which Models exempt + # when they're constructed. + model._layers = data_structures.NoDependency([]) for name, value in model._original_attributes_cache.items(): + if not isinstance(value, checkpointable.CheckpointableBase): + # If this value is not already checkpointable, it's probably that way + # for a reason; we don't want to start tracking data structures that the + # original Model didn't. + value = data_structures.NoDependency(value) setattr(model, name, value) model._original_attributes_cache = None else: diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4814275fd5..361778570b 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase): constraints on inputs that can be accepted by the layer. """ + @checkpointable.no_automatic_dependency_tracking def __init__(self, trainable=True, name=None, dtype=None, **kwargs): # These properties should be set by the user via keyword arguments. 
# note that 'dtype', 'input_shape' and 'batch_input_shape' @@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase): @activity_regularizer.setter def activity_regularizer(self, regularizer): """Optional regularizer function for the output of this layer.""" - self._activity_regularizer = regularizer + self._activity_regularizer = self._no_dependency(regularizer) @property def trainable_weights(self): @@ -658,7 +659,8 @@ class Layer(checkpointable.CheckpointableBase): self._compute_previous_mask): previous_mask = collect_previous_mask(inputs) if not hasattr(self, '_call_fn_args'): - self._call_fn_args = function_utils.fn_args(self.call) + self._call_fn_args = self._no_dependency( + function_utils.fn_args(self.call)) if ('mask' in self._call_fn_args and 'mask' not in kwargs and not generic_utils.is_all_none(previous_mask)): # The previous layer generated a mask, and mask was not explicitly pass diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index aa84eaa8ab..a4d96de74f 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -81,6 +81,20 @@ class Network(base_layer.Layer): # Subclassed network self._init_subclassed_network(**kwargs) + # Several Network methods have "no_automatic_dependency_tracking" + # annotations. Since Network does automatic dependency tracking on attribute + # assignment, including for common data structures such as lists, by default + # we'd have quite a few empty dependencies which users don't care about (or + # would need some way to ignore dependencies automatically, which is confusing + # when applied to user code). Some attributes, such as _layers, would cause + # structural issues (_layers being the place where Layers assigned to tracked + # attributes are stored). + # + # Aside from these aesthetic and structural issues, useless dependencies on + # empty lists shouldn't cause issues; adding or removing them will not break + # checkpoints, but may cause "all Python objects matched" assertions to fail + # (in which case less strict assertions may be substituted if necessary). + @checkpointable.no_automatic_dependency_tracking def _base_init(self, name=None): # The following are implemented as property functions: # self.trainable_weights @@ -135,6 +149,7 @@ class Network(base_layer.Layer): # restore operations when graph building. self._in_progress_restore_finalizer = None + @checkpointable.no_automatic_dependency_tracking def _init_graph_network(self, inputs, outputs, name=None): self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT # Normalize and set self.inputs, self.outputs. @@ -293,6 +308,7 @@ class Network(base_layer.Layer): for layer in self._output_layers: self.output_names.append(layer.name) + @checkpointable.no_automatic_dependency_tracking def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False @@ -362,10 +378,31 @@ class Network(base_layer.Layer): self._track_checkpointable( layer, name='layer-%d' % layer_index, overwrite=True) + def _no_dependency(self, value): + """Override to allow `Layer` to disable dependency tracking. + + `CheckpointableBase` defines this method, whose semantics are "if a subclass + does dependency tracking, this method exempts `value`." Layer uses + `_no_dependency` to exempt some of its attribute assignments (conditional on + attribute assignment causing tracking in the subclass). 
+ + Args: + value: An object which will be assigned to an object attribute, whose + value should not be tracked. + + Returns: + A wrapped object which, when assigned to an attribute, will not be + tracked (`value` will be stored in the attribute). + """ + return data_structures.NoDependency(value) + def __setattr__(self, name, value): - no_dependency = isinstance(value, checkpointable.NoDependency) - if no_dependency: - value = value.value + if not getattr(self, '_setattr_tracking', True): + super(Network, self).__setattr__(name, value) + return + no_dependency = isinstance(value, data_structures.NoDependency) + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) if isinstance(value, ( base_layer.Layer, Network, @@ -377,7 +414,9 @@ class Network(base_layer.Layer): 'forgot to call `super(YourClass, self).__init__()`.' ' Always start with this line.') if not is_graph_network: - if value not in self._layers: + # We need to check object identity to avoid de-duplicating empty + # container types which compare equal. + if not any((layer is value for layer in self._layers)): self._layers.append(value) if hasattr(value, '_use_resource_variables'): # In subclassed models, legacy layers (tf.layers) must always use @@ -385,12 +424,6 @@ class Network(base_layer.Layer): value._use_resource_variables = True if (not no_dependency and isinstance(value, checkpointable.CheckpointableBase)): - # Layer (and therefore Network/Model) inherit from CheckpointableBase - # rather than Checkpointable, which means there is no Checkpointable - # __setattr__ override (it would be a performance issue for functional - # layers). Therefore Model tracks Checkpointable objects itself. - self._track_checkpointable( - checkpointable=value, name=name, overwrite=True) if ( # For subclassed models only, users may add extra weights/variables # simply by assigning them to attributes. not self._is_graph_network @@ -493,7 +526,8 @@ class Network(base_layer.Layer): @property def layers(self): - return self._layers + return checkpointable_layer_utils.filter_empty_layer_containers( + self._layers) def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index cd76f08a32..371504a503 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -29,6 +29,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.engine.training import Model from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -108,6 +109,7 @@ class Sequential(Model): return self._layers[1:] return self._layers + @checkpointable.no_automatic_dependency_tracking def add(self, layer): """Adds a layer instance on top of the layer stack. @@ -191,6 +193,7 @@ class Sequential(Model): else: self._layers.append(layer) + @checkpointable.no_automatic_dependency_tracking def pop(self): """Removes the last layer in the model. 
@@ -210,6 +213,7 @@ class Sequential(Model): self.outputs = [self.layers[-1].output] self.build() + @checkpointable.no_automatic_dependency_tracking def build(self, input_shape=None): if input_shape and not self.inputs: batch_shape = tuple(input_shape) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index fce6cbdb7a..8e632651fa 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -42,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -115,6 +116,7 @@ class Model(Network): # Create a cache for dataset - uninitialized iterators self._dataset_iterator_cache = weakref.WeakKeyDictionary() + @checkpointable.no_automatic_dependency_tracking def compile(self, optimizer, loss=None, @@ -178,6 +180,11 @@ class Model(Network): raise ValueError('Only TF native optimizers are supported in Eager mode.') self.optimizer = optimizers.get(optimizer) + # We've disabled automatic dependency tracking for this method, but do want + # to add a checkpoint dependency on the optimizer if it's checkpointable. + if isinstance(self.optimizer, checkpointable.CheckpointableBase): + self._track_checkpointable( + self.optimizer, name='optimizer', overwrite=True) self.loss = loss self.metrics = metrics or [] self.loss_weights = loss_weights @@ -941,6 +948,7 @@ class Model(Network): str(x[0].shape[0]) + ' samples') return x, y, sample_weights + @checkpointable.no_automatic_dependency_tracking def _set_inputs(self, inputs, training=None): """Set model's input and output specs based on the input data received. @@ -989,6 +997,7 @@ class Model(Network): else: self._symbolic_set_inputs(inputs, training=training) + @checkpointable.no_automatic_dependency_tracking def _eager_set_inputs(self, inputs): """Set model's input and output specs based on the input data received. @@ -1041,6 +1050,7 @@ class Model(Network): 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] self.built = True + @checkpointable.no_automatic_dependency_tracking def _symbolic_set_inputs(self, inputs, outputs=None, training=None): """Set model's inputs and output specs based. 
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index b7e16a41dd..3ac4852eff 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -679,8 +679,8 @@ class ModelSubclassingTest(test.TestCase): def __init__(self): super(Foo, self).__init__() self.isdep = keras.layers.Dense(1) - self.notdep = checkpointable.NoDependency(keras.layers.Dense(2)) - self.notdep_var = checkpointable.NoDependency( + self.notdep = data_structures.NoDependency(keras.layers.Dense(2)) + self.notdep_var = data_structures.NoDependency( resource_variable_ops.ResourceVariable(1., name='notdep_var')) m = Foo() diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index b02cafcf61..0b440185ca 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking as checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -688,12 +688,13 @@ class Nadam(Optimizer): return dict(list(base_config.items()) + list(config.items())) -class TFOptimizer(Optimizer, checkpointable.Checkpointable): +class TFOptimizer(Optimizer, checkpointable.CheckpointableBase): """Wrapper class for native TensorFlow optimizers. """ def __init__(self, optimizer): # pylint: disable=super-init-not-called self.optimizer = optimizer + self._track_checkpointable(optimizer, name='optimizer') with K.name_scope(self.__class__.__name__): self.iterations = K.variable(0, dtype='int64', name='iterations') diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 54f359489e..35007653a0 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -47,6 +47,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base", + ":data_structures", ], ) diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index 99c8098eca..e9c8c21905 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saveable_object from tensorflow.python.util import nest from tensorflow.python.util import serialization +from tensorflow.python.util import tf_decorator # Key where the object graph proto is saved in a TensorBundle @@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple( ]) +def no_automatic_dependency_tracking(method): + """Disables automatic dependency tracking on attribute assignment. 
+ + Use to decorate any method of a Checkpointable object. Attribute assignment in + that method will not add dependencies (also respected in Model). Harmless if + used in a class which does not do automatic dependency tracking (which means + it's safe to use in base classes which may have subclasses which also inherit + from Checkpointable). + + Args: + method: The method to decorate. + Returns: + A decorated method which sets and un-sets automatic dependency tracking for + the object the method is called on (not thread safe). + """ + + def _method_wrapper(self, *args, **kwargs): + previous_value = getattr(self, "_setattr_tracking", True) + self._setattr_tracking = False # pylint: disable=protected-access + try: + method(self, *args, **kwargs) + finally: + self._setattr_tracking = previous_value # pylint: disable=protected-access + + return tf_decorator.make_decorator( + target=method, decorator_func=_method_wrapper) + + class CheckpointableBase(object): """Base class for `Checkpointable` objects without automatic dependencies. @@ -349,6 +378,11 @@ class CheckpointableBase(object): checks. """ + # CheckpointableBase does not do automatic dependency tracking, but uses the + # no_automatic_dependency_tracking decorator so it can avoid adding + # dependencies if a subclass is Checkpointable / inherits from Model (both of + # which have __setattr__ overrides). + @no_automatic_dependency_tracking def _maybe_initialize_checkpointable(self): """Initialize dependency management. @@ -386,6 +420,10 @@ class CheckpointableBase(object): # building. self._name_based_restores = set() + def _no_dependency(self, value): + """If automatic dependency tracking is enabled, ignores `value`.""" + return value + def _name_based_attribute_restore(self, checkpoint): """Restore the object's attributes from a name-based checkpoint.""" self._name_based_restores.add(checkpoint) @@ -733,28 +771,3 @@ class CheckpointableBase(object): return {OBJECT_CONFIG_JSON_KEY: functools.partial( PythonStringStateSaveable, state_callback=_state_callback)} - - -class NoDependency(object): - """Allows attribute assignment to `Checkpointable` objects with no dependency. - - Example usage: - ```python - obj = Checkpointable() - obj.has_dependency = tf.Variable(0., name="dep") - obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) - assert obj.no_dependency.name == "nodep:0" - ``` - - `obj` in this example has a dependency on the variable "dep", and both - attributes contain un-wrapped `Variable` objects. - - `NoDependency` also works with `tf.keras.Model`, but only for checkpoint - dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) - `Layer` to the attribute without a checkpoint dependency, but the `Model` will - still track the `Layer` (so it will appear in `Model.layers`, and its - variables will appear in `Model.variables`). 
- """ - - def __init__(self, value): - self.value = value diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index c46585b417..019d43f09c 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -22,49 +22,126 @@ import collections import six from tensorflow.python.ops import variables -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import base from tensorflow.python.training.checkpointable import layer_utils -# TODO(allenl): We could track regular Python data structures which get assigned -# to Checkpointable objects. Making this work with restore-on-create would be -# tricky; we'd need to re-create nested structures with our own wrapped objects -# on assignment to an attribute, and track the user's original structure to make -# sure they don't modify it except through the wrappers (since we could save the -# user's updated structure, but would have no way to support restore-on-create -# for those modifications). -# TODO(allenl): A dictionary data structure would be good too. -class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase): +class NoDependency(object): + """Allows attribute assignment to `Checkpointable` objects with no dependency. + + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + +def _wrap_or_unwrap(value): + """Wraps basic data structures, unwraps NoDependency objects.""" + if isinstance(value, NoDependency): + return value.value + if isinstance(value, base.CheckpointableBase): + return value # Skip conversion for already checkpointable objects. + elif isinstance(value, list): + return _ListWrapper(value) + else: + return value + # TODO(allenl): Handle other common data structures. Tuples will require + # special casing (tuple subclasses are not weak referenceable, so replacement + # with a wrapper that subclasses tuple on attribute assignment works poorly, + # and replacement with a wrapper that isn't a tuple is also problematic), + # probably a tree traversal where the leaves are non-tuples(/namedtuples) to + # come up with names. Dictionaries should look like lists. + + +def sticky_attribute_assignment(checkpointable, name, value): + """Adds dependencies, generally called from __setattr__. + + This behavior is shared between Checkpointable and Model. + + Respects NoDependency indicators, but otherwise makes checkpointable objects + out of common data structures and tracks objects by their attribute names. + + Args: + checkpointable: The object to add dependencies to (generally the one having + an attribute assigned). + name: The attribute name being assigned. + value: The value being assigned. 
Not necessarily a checkpointable object. + + Returns: + The value which should be stored in the attribute (unwrapped from a + NoDependency object if necessary). + """ + if isinstance(value, NoDependency): + add_dependency = False + else: + add_dependency = True + value = _wrap_or_unwrap(value) + if not add_dependency: + return value + if isinstance(value, base.CheckpointableBase): + checkpointable._track_checkpointable( # pylint: disable=protected-access + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + overwrite=True) + return value + + +class CheckpointableDataStructure(base.CheckpointableBase): """Base class for data structures which contain checkpointable objects.""" def __init__(self): + # An append-only ordered set self._layers = [] + self.trainable = True self._extra_variables = [] def _track_value(self, value, name): """Add a dependency on `value`.""" - if isinstance(value, checkpointable_lib.CheckpointableBase): - self._track_checkpointable(value, name=name) - if isinstance(value, variables.Variable): - self._extra_variables.append(value) - else: + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + if isinstance(value, variables.Variable): + self._extra_variables.append(value) + if not isinstance(value, base.CheckpointableBase): raise ValueError( ("Only checkpointable objects (such as Layers or Optimizers) may be " "stored in a List object. Got %s, which does not inherit from " "CheckpointableBase.") % (value,)) if (isinstance(value, CheckpointableDataStructure) or layer_utils.is_layer(value)): - if value not in self._layers: + # Check for object-identity rather than with __eq__ to avoid + # de-duplicating empty container types. Automatically generated list + # wrappers keep things like "[] == []" true, which means "[] in [[]]" is + # also true. This becomes not true once one of the lists is mutated. + if not any((layer is value for layer in self._layers)): self._layers.append(value) if hasattr(value, "_use_resource_variables"): # In subclassed models, legacy layers (tf.layers) must always use # resource variables. value._use_resource_variables = True # pylint: disable=protected-access + return value @property def layers(self): - return self._layers + return layer_utils.filter_empty_layer_containers(self._layers) @property def trainable_weights(self): @@ -164,24 +241,28 @@ class List(CheckpointableDataStructure, collections.Sequence): def __init__(self, *args, **kwargs): """Construct a new sequence. 
Arguments are passed to `list()`.""" super(List, self).__init__() - self._storage = list(*args, **kwargs) + self._storage = self._make_storage(*args, **kwargs) for index, element in enumerate(self._storage): - self._track_value(element, name=self._name_element(index)) + self._storage[index] = self._track_value( + element, name=self._name_element(index)) + + def _make_storage(self, *args, **kwargs): + """Determines the backing storage (overridden in subclasses).""" + return list(*args, **kwargs) def _name_element(self, index): return "%d" % (index,) def append(self, value): """Add a new checkpointable value.""" - self._track_value(value, self._name_element(len(self._storage))) + value = self._track_value(value, self._name_element(len(self._storage))) self._storage.append(value) def extend(self, values): """Add a sequence of checkpointable values.""" - for index_offset, value in enumerate(values): - self._track_value( - value, name=self._name_element(len(self._storage) + index_offset)) - self._storage.extend(values) + for value in values: + self._storage.append(self._track_value( + value, name=self._name_element(len(self._storage)))) def __iadd__(self, values): self.extend(values) @@ -189,9 +270,12 @@ class List(CheckpointableDataStructure, collections.Sequence): def __add__(self, other): if isinstance(other, List): - return List(self._storage + other._storage) # pylint: disable=protected-access + return self.__class__(self._storage + other._storage) # pylint: disable=protected-access else: - return List(self._storage + other) + return self.__class__(self._storage + other) + + def __radd__(self, other): + return self + other def __getitem__(self, key): return self._storage[key] @@ -203,6 +287,144 @@ class List(CheckpointableDataStructure, collections.Sequence): return "List(%s)" % (repr(self._storage),) +class _ListWrapper(List, collections.MutableSequence, + # Shadowed, but there for isinstance checks. + list): + """Wraps the built-in `list` to support restore-on-create for variables. + + Unlike `List`, this sequence type is mutable in the same ways built-in lists + are. Instead of throwing an error immediately like `List`, it records + problematic mutations (e.g. assigning a new element to a position already + occupied, meaning both elements get the same names at different times) and + refuses to save. + + On assignment to an attribute of a Model or Checkpointable object, Python + lists are replaced with _ListWrapper. Wrapping a list in a + `tf.contrib.checkpoint.NoDependency` object prevents this. + """ + + def __init__(self, wrapped_list): + """Construct a new list wrapper. + + Args: + wrapped_list: The initial value of the data structure. A shallow copy may + be maintained for error checking. `wrapped_list` itself should not be + modified directly after constructing the `_ListWrapper`, and if changes + are detected the `_ListWrapper` will throw an exception on save. + """ + # Monotonic flags which indicate this object would not be restored properly, + # and therefore should throw an error on save to avoid giving the impression + # that restoring it will work. 
+ self._non_append_mutation = False + self._external_modification = False + super(_ListWrapper, self).__init__(wrapped_list) + self._last_wrapped_list_snapshot = list(self._storage) + + def _make_storage(self, wrapped_list): + """Use the user's original list for storage.""" + return wrapped_list + + def _check_external_modification(self): + """Checks for any changes to the wrapped list not through the wrapper.""" + if self._external_modification or self._non_append_mutation: + return + if self._storage != self._last_wrapped_list_snapshot: + self._external_modification = True + self._last_wrapped_list_snapshot = None + + def _update_snapshot(self): + """Acknowledges tracked changes to the wrapped list.""" + if self._external_modification or self._non_append_mutation: + return + self._last_wrapped_list_snapshot = list(self._storage) + + @property + def _checkpoint_dependencies(self): + self._check_external_modification() + if self._non_append_mutation: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). A list element was replaced " + "(__setitem__), deleted, or inserted. In order to support " + "restoration on object creation, tracking is exclusively for " + "append-only data structures.\n\nIf you don't need this list " + "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " + "object; it will be automatically un-wrapped and subsequently " + "ignored." % (self,))) + if self._external_modification: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). The wrapped list was modified " + "outside the wrapper (its final value was %s, its value when a " + "checkpoint dependency was added was %s), which breaks restoration " + "on object creation.\n\nIf you don't need this list checkpointed, " + "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be " + "automatically un-wrapped and subsequently ignored." % ( + self, self._storage, self._last_wrapped_list_snapshot))) + return super(_ListWrapper, self)._checkpoint_dependencies + + def __delitem__(self, key): + self._non_append_mutation = True + del self._storage[key] + + def __setitem__(self, key, value): + self._non_append_mutation = True + self._storage[key] = value + + def append(self, value): + """Add a new checkpointable value.""" + self._check_external_modification() + super(_ListWrapper, self).append(value) + self._update_snapshot() + + def extend(self, values): + """Add a sequence of checkpointable values.""" + self._check_external_modification() + super(_ListWrapper, self).extend(values) + self._update_snapshot() + + def __eq__(self, other): + return self._storage == getattr(other, "_storage", other) + + def __ne__(self, other): + return self._storage != getattr(other, "_storage", other) + + def __lt__(self, other): + return self._storage < getattr(other, "_storage", other) + + def __le__(self, other): + return self._storage <= getattr(other, "_storage", other) + + def __gt__(self, other): + return self._storage > getattr(other, "_storage", other) + + def __ge__(self, other): + return self._storage >= getattr(other, "_storage", other) + + def __hash__(self): + # List wrappers need to compare like regular lists, and so like regular + # lists they don't belong in hash tables. 
+ raise TypeError("unhashable type: 'ListWrapper'") + + def insert(self, index, obj): + self._non_append_mutation = True + self._storage.insert(index, obj) + + def _track_value(self, value, name): + """Allows storage of non-checkpointable objects.""" + try: + value = super(_ListWrapper, self)._track_value(value=value, name=name) + except ValueError: + # Even if this value isn't checkpointable, we need to make sure + # NoDependency objects get unwrapped. + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + return value + + def __repr__(self): + return "ListWrapper(%s)" % (repr(self._storage),) + + class Mapping(CheckpointableDataStructure, collections.Mapping): """An append-only checkpointable mapping data structure with string keys. @@ -217,8 +439,10 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): """Construct a new sequence. Arguments are passed to `dict()`.""" super(Mapping, self).__init__() self._storage = dict(*args, **kwargs) - for key, value in self._storage.items(): - self._track_value(value, name=self._name_element(key)) + self._storage.update( + {key: self._track_value( + value, name=self._name_element(key)) + for key, value in self._storage.items()}) def _name_element(self, key): if not isinstance(key, six.string_types): @@ -228,13 +452,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): return str(key) def __setitem__(self, key, value): + name = self._name_element(key) + value = self._track_value(value, name=name) current_value = self._storage.setdefault(key, value) if current_value is not value: raise ValueError( ("Mappings are an append-only data structure. Tried to overwrite the " "key '%s' with value %s, but it already contains %s") % (key, value, current_value)) - self._track_value(value, name=self._name_element(key)) def update(self, *args, **kwargs): for key, value in dict(*args, **kwargs).items(): diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index ce5852dd6e..ec8c9da809 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking class HasList(training.Model): @@ -113,6 +114,19 @@ class ListTests(test.TestCase): model(model_input) self.assertEqual(2, len(model.losses)) + def testModelContainersCompareEqual(self): + class HasEqualContainers(training.Model): + + def __init__(self): + super(HasEqualContainers, self).__init__() + self.l1 = [] + self.l2 = [] + + model = HasEqualContainers() + model.l1.append(HasEqualContainers()) + model.l2.append(HasEqualContainers()) + self.assertEqual([model.l1, model.l2], model.layers) + def testNotCheckpointable(self): class NotCheckpointable(object): pass @@ -158,11 +172,62 @@ class ListTests(test.TestCase): self.assertEqual([v], l.trainable_weights) self.assertEqual([v2], l.non_trainable_weights) + def testListWrapperBasic(self): + # _ListWrapper, unlike List, compares like the built-in list type (since it + # is used to automatically replace lists). 
+ a = tracking.Checkpointable() + b = tracking.Checkpointable() + self.assertEqual([a, a], + [a, a]) + self.assertEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([a, a])) + self.assertEqual([a, a], + data_structures._ListWrapper([a, a])) + self.assertEqual(data_structures._ListWrapper([a, a]), + [a, a]) + self.assertNotEqual([a, a], + [b, a]) + self.assertNotEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([b, a])) + self.assertNotEqual([a, a], + data_structures._ListWrapper([b, a])) + self.assertLess([a], [a, b]) + self.assertLess(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertLessEqual([a], [a, b]) + self.assertLessEqual(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertGreater([a, b], [a]) + self.assertGreater(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertGreaterEqual([a, b], [a]) + self.assertGreaterEqual(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertEqual([a], data_structures._ListWrapper([a])) + self.assertEqual([a], list(data_structures.List([a]))) + self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a]) + self.assertEqual([a, a], [a] + data_structures._ListWrapper([a])) + self.assertIsInstance(data_structures._ListWrapper([a]), list) + + def testWrapperChangesList(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l_wrapper.append(1) + self.assertEqual([1], l) + + def testListChangesWrapper(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l.append(1) + self.assertEqual([1], l_wrapper) + def testHashing(self): has_sequences = set([data_structures.List(), data_structures.List()]) self.assertEqual(2, len(has_sequences)) self.assertNotIn(data_structures.List(), has_sequences) + with self.assertRaises(TypeError): + has_sequences.add(data_structures._ListWrapper([])) class HasMapping(training.Model): diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py index fdcf963d32..978fcb2252 100644 --- a/tensorflow/python/training/checkpointable/layer_utils.py +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -30,6 +30,14 @@ def is_layer(obj): and hasattr(obj, "variables")) +def filter_empty_layer_containers(layer_list): + """Filter out empty Layer-like containers.""" + return [layer for layer in layer_list + # Filter out only empty Checkpointable data structures. Empty Networks + # will still show up in Model.layers. + if is_layer(layer) or getattr(layer, "layers", True)] + + def gather_trainable_weights(trainable, sub_layers, extra_variables): """Lists the trainable weights for an object with sub-layers. diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py index 00e14ac982..bd0bed9d46 100644 --- a/tensorflow/python/training/checkpointable/tracking.py +++ b/tensorflow/python/training/checkpointable/tracking.py @@ -18,31 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.training.checkpointable import base - - -class NoDependency(object): - """Allows attribute assignment to `Checkpointable` objects with no dependency. 
- - Example usage: - ```python - obj = Checkpointable() - obj.has_dependency = tf.Variable(0., name="dep") - obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) - assert obj.no_dependency.name == "nodep:0" - ``` - - `obj` in this example has a dependency on the variable "dep", and both - attributes contain un-wrapped `Variable` objects. - - `NoDependency` also works with `tf.keras.Model`, but only for checkpoint - dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) - `Layer` to the attribute without a checkpoint dependency, but the `Model` will - still track the `Layer` (so it will appear in `Model.layers`, and its - variables will appear in `Model.variables`). - """ - - def __init__(self, value): - self.value = value +from tensorflow.python.training.checkpointable import data_structures class NotCheckpointable(object): @@ -86,18 +62,11 @@ class Checkpointable(base.CheckpointableBase): def __setattr__(self, name, value): """Support self.foo = checkpointable syntax.""" - # Perform the attribute assignment, and potentially call other __setattr__ - # overrides such as that for tf.keras.Model. - no_dependency = isinstance(value, NoDependency) - if no_dependency: - value = value.value + if getattr(self, "_setattr_tracking", True): + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) super(Checkpointable, self).__setattr__(name, value) - if not no_dependency and isinstance(value, base.CheckpointableBase): - self._track_checkpointable( - value, name=name, - # Allow the user to switch the Checkpointable which is tracked by this - # name, since assigning a new variable to an attribute has - # historically been fine (e.g. Adam did this). - # TODO(allenl): Should this be a warning once Checkpointable save/load - # is usable? - overwrite=True) + + def _no_dependency(self, value): + """Override to allow CheckpointableBase to disable dependency tracking.""" + return data_structures.NoDependency(value) diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py index baf6f57efb..f0178b074d 100644 --- a/tensorflow/python/training/checkpointable/tracking_test.py +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -16,8 +16,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +import numpy + +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util +from tensorflow.python.util import nest class InterfaceTests(test.TestCase): @@ -27,7 +38,7 @@ class InterfaceTests(test.TestCase): root.leaf = tracking.Checkpointable() root.leaf = root.leaf duplicate_name_dep = tracking.Checkpointable() - with self.assertRaises(ValueError): + with self.assertRaisesRegexp(ValueError, "already declared"): root._track_checkpointable(duplicate_name_dep, name="leaf") # No error; we're overriding __setattr__, so we can't really stop people # from doing this while maintaining backward compatibility. 
@@ -39,11 +50,119 @@ class InterfaceTests(test.TestCase): hasdep = tracking.Checkpointable() root.hasdep = hasdep nodep = tracking.Checkpointable() - root.nodep = tracking.NoDependency(nodep) + root.nodep = data_structures.NoDependency(nodep) self.assertEqual(1, len(root._checkpoint_dependencies)) self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) self.assertIs(root.hasdep, hasdep) self.assertIs(root.nodep, nodep) + class NoDependencyModel(training.Model): + + @base.no_automatic_dependency_tracking + def __init__(self): + super(NoDependencyModel, self).__init__() + self.a = [] + self.b = tracking.Checkpointable() + + nodeps = NoDependencyModel() + self.assertEqual([nodeps], util.list_objects(nodeps)) + + def testListBasic(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + direct_a_dep, = a._checkpoint_dependencies + self.assertEqual("l", direct_a_dep.name) + self.assertIn(b, direct_a_dep.ref) + self.assertIn(c, direct_a_dep.ref) + + @test_util.run_in_graph_and_eager_modes + def testMutationDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.insert(0, c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testOutOfBandEditDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + held_reference = [b] + a.l = held_reference + c = tracking.Checkpointable() + held_reference.append(c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testNestedLists(self): + a = tracking.Checkpointable() + a.l = [] + b = tracking.Checkpointable() + a.l.append([b]) + c = tracking.Checkpointable() + a.l[0].append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + a.l[0].append(1) + d = tracking.Checkpointable() + a.l[0].append(d) + a_deps = util.list_objects(a) + self.assertIn(d, a_deps) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + self.assertNotIn(1, a_deps) + e = tracking.Checkpointable() + f = tracking.Checkpointable() + a.l1 = [[], [e]] + a.l1[0].append(f) + a_deps = util.list_objects(a) + self.assertIn(e, a_deps) + self.assertIn(f, a_deps) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l[0].append(data_structures.NoDependency([])) + a.l[0][-1].append(5) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + # Dirtying the inner list means the root object is unsaveable. 
+ a.l[0][1] = 2 + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testNoDepList(self): + a = training.Model() + a.l1 = data_structures.NoDependency([]) + a.l1.insert(1, 0) + self.assertTrue(isinstance(a.l1, list)) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l2 = [] + a.l2.insert(1, 0) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testAssertions(self): + a = tracking.Checkpointable() + a.l = [numpy.zeros([2, 2])] + self.assertAllEqual([numpy.zeros([2, 2])], a.l) + self.assertAllClose([numpy.zeros([2, 2])], a.l) + nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])]) + a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])] + self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])], + self.evaluate(a.tensors)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index e0f61137b1..6ae5765b13 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -40,6 +40,7 @@ from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saveable_object as saveable_object_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking from tensorflow.python.util import deprecation from tensorflow.python.util import tf_contextlib @@ -93,7 +94,7 @@ class _CheckpointRestoreCoordinator(object): # use them (for example because of inconsistent references when # loading). Used to make status assertions fail when loading checkpoints # that don't quite match. - self.all_python_objects = weakref.WeakSet() + self.all_python_objects = _ObjectIdentityWeakSet() self.save_path = save_path self.dtype_map = dtype_map # When graph building, contains a list of ops to run to restore objects from @@ -272,11 +273,129 @@ def object_metadata(save_path): return object_graph_proto +class _ObjectIdentityWrapper(object): + """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. + + Since __eq__ is based on object identity, it's safe to also define __hash__ + based on object ids. This lets us add unhashable types like checkpointable + _ListWrapper objects to object-identity collections. + """ + + def __init__(self, wrapped): + self._wrapped = wrapped + + @property + def unwrapped(self): + return self._wrapped + + def __eq__(self, other): + if isinstance(other, _ObjectIdentityWrapper): + return self._wrapped is other._wrapped # pylint: disable=protected-access + return self._wrapped is other + + def __hash__(self): + # Wrapper id() is also fine for weakrefs. In fact, we rely on + # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is + # weakref.ref(a) in _WeakObjectIdentityWrapper. 
+ return id(self._wrapped) + + +class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): + + def __init__(self, wrapped): + super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) + + @property + def unwrapped(self): + return self._wrapped() + + +class _ObjectIdentityDictionary(collections.MutableMapping): + """A mutable mapping data structure which compares using "is". + + This is necessary because we have checkpointable objects (_ListWrapper) which + have behavior identical to built-in Python lists (including being unhashable + and comparing based on the equality of their contents by default). + """ + + def __init__(self): + self._storage = {} + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __getitem__(self, key): + return self._storage[self._wrap_key(key)] + + def __setitem__(self, key, value): + self._storage[self._wrap_key(key)] = value + + def __delitem__(self, key): + del self._storage[self._wrap_key(key)] + + def __len__(self): + return len(self._storage) + + def __iter__(self): + for key in self._storage: + yield key.unwrapped + + +class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary): + """Like weakref.WeakKeyDictionary, but compares objects with "is".""" + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self._storage)) + + def __iter__(self): + keys = self._storage.keys() + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + del self[key] + else: + yield unwrapped + + +class _ObjectIdentityWeakSet(collections.MutableSet): + """Like weakref.WeakSet, but compares objects with "is".""" + + def __init__(self): + self._storage = set() + + def __contains__(self, key): + return _WeakObjectIdentityWrapper(key) in self._storage + + def discard(self, key): + self._storage.discard(_WeakObjectIdentityWrapper(key)) + + def add(self, key): + self._storage.add(_WeakObjectIdentityWrapper(key)) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self)) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + self.discard(key) + else: + yield unwrapped + + def _breadth_first_checkpointable_traversal(root_checkpointable): """Find shortest paths to all variables owned by dependencies of root.""" bfs_sorted = [] to_visit = collections.deque([root_checkpointable]) - path_to_root = {root_checkpointable: ()} + path_to_root = _ObjectIdentityDictionary() + path_to_root[root_checkpointable] = () while to_visit: current_checkpointable = to_visit.popleft() if isinstance(current_checkpointable, tracking.NotCheckpointable): @@ -337,7 +456,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path): def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): """Gather and name slot variables.""" non_slot_objects = list(checkpointable_objects) - slot_variables = {} + slot_variables = _ObjectIdentityDictionary() for checkpointable in non_slot_objects: if isinstance(checkpointable, optimizer_lib.Optimizer): naming_scheme = _slot_variable_naming_for_optimizer( @@ -500,11 +619,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache): """ checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in 
enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id slot_variables = _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -535,11 +655,12 @@ def list_objects(root_checkpointable): # to run. checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -988,7 +1109,7 @@ class CheckpointableSaver(object): else: # Maps Checkpointable objects -> attribute names -> SaveableObjects, to # avoid re-creating SaveableObjects when graph building. - self._saveable_object_cache = weakref.WeakKeyDictionary() + self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary() @property def _root_checkpointable(self): @@ -1310,7 +1431,7 @@ class Checkpoint(tracking.Checkpointable): with ops.device("/cpu:0"): # add_variable creates a dependency named "save_counter"; NoDependency # prevents creating a second dependency named "_save_counter". - self._save_counter = tracking.NoDependency( + self._save_counter = data_structures.NoDependency( add_variable(self, name="save_counter", initializer=0, dtype=dtypes.int64)) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 1104768ae8..d63f59a8c8 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True): Args: nest1: an arbitrarily nested structure. nest2: an arbitrarily nested structure. - check_types: if `True` (default) types of sequences are checked as - well, including the keys of dictionaries. If set to `False`, for example - a list and a tuple of objects will look the same if they have the same + check_types: if `True` (default) types of sequences are checked as well, + including the keys of dictionaries. If set to `False`, for example a + list and a tuple of objects will look the same if they have the same size. Note that namedtuples with identical name and fields are always - considered to have the same shallow structure. + considered to have the same shallow structure. Two types will also be + considered the same if they are both list subtypes (which allows "list" + and "_ListWrapper" from checkpointable dependency tracking to compare + equal). Raises: ValueError: If the two structures do not have the same number of elements or diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index c79d8a8445..366f8a0deb 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -394,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, type2->tp_name); return true; } - } else if (type1 != type2) { + } else if (type1 != type2 + /* If both sequences are list types, don't complain. 
This allows + one to be a list subclass (e.g. _ListWrapper used for automatic + dependency tracking.) */ + && !(PyList_Check(o1) && PyList_Check(o2))) { *is_type_error = true; *error_msg = tensorflow::strings::StrCat( "The two namedtuples don't have the same sequence type. " |