aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-06-29 14:02:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 14:04:49 -0700
commitdcaa037571ab0933977f70574f4f78875155ae20 (patch)
tree4968e1966ca334f42296beae6cb1ecd8d483215e
parentb3c163a754574faed4337f869c2d650a9f45c09c (diff)
Auto tracking for Python lists assigned to attributes of Model/Checkpointable
Conceptually lists just get replaced with a list-like wrapper. A shallow copy is maintained for error checking (since appends to it aren't monitored, we can't do restore-on-create for variables unless it's being modified through the wrapper). There are lots of other details. I gave up on generalizing our isinstance(obj, (list, tuple)) checks and just subclassed list. Behaving like a list means the type should be unhashable, which requires some workarounds when we're collecting objects (object-identity collections, and object-identity versions of weak reference containers). Adds a decorator for exempting whole methods from automatic dependency tracking so we don't need to track down every last self.inputs = [] statement to avoid polluting dependencies. There's a TODO for tuples and dictionaries. PiperOrigin-RevId: 202703271
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py3
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/estimator/keras.py13
-rw-r--r--tensorflow/python/keras/engine/base_layer.py6
-rw-r--r--tensorflow/python/keras/engine/network.py56
-rw-r--r--tensorflow/python/keras/engine/sequential.py4
-rw-r--r--tensorflow/python/keras/engine/training.py10
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py6
-rw-r--r--tensorflow/python/keras/optimizers.py5
-rw-r--r--tensorflow/python/training/checkpointable/BUILD1
-rw-r--r--tensorflow/python/training/checkpointable/base.py63
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py283
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py65
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py8
-rw-r--r--tensorflow/python/training/checkpointable/tracking.py47
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py123
-rw-r--r--tensorflow/python/training/checkpointable/util.py151
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/util.cc6
20 files changed, 728 insertions, 137 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 8c1ce5c2a2..2fbaa31d5e 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -44,8 +44,8 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa
from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
+from tensorflow.python.training.checkpointable.data_structures import NoDependency
from tensorflow.python.training.checkpointable.tracking import Checkpointable
-from tensorflow.python.training.checkpointable.tracking import NoDependency
from tensorflow.python.training.checkpointable.util import capture_dependencies
from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 64d056bd68..ac85c7be80 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -79,7 +80,7 @@ class UniqueNameTrackerTests(test.TestCase):
resource_variable_ops.ResourceVariable(4.), "y"))
slots.append(slotdeps.track(
resource_variable_ops.ResourceVariable(5.), "x"))
- self.slots = slots
+ self.slots = data_structures.NoDependency(slots)
manager = SlotManager()
self.evaluate([v.initializer for v in manager.slots])
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index f3b788f931..e037925961 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -361,7 +361,7 @@ class _ListFetchMapper(_FetchMapper):
for m, vi in zip(self._mappers, self._value_indices):
results.append(m.build_results([values[j] for j in vi]))
# Return a value of the original type of the fetches.
- if self._fetch_type == list:
+ if issubclass(self._fetch_type, list):
return results
elif self._fetch_type == tuple:
return tuple(results)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 5769f5739c..cb37f99704 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -45,6 +45,8 @@ from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -241,8 +243,17 @@ def _in_place_subclassed_model_state_restoration(model):
# Restore layers and build attributes
if (hasattr(model, '_original_attributes_cache') and
model._original_attributes_cache is not None):
- model._layers = []
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
setattr(model, name, value)
model._original_attributes_cache = None
else:
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 4814275fd5..361778570b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase):
constraints on inputs that can be accepted by the layer.
"""
+ @checkpointable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase):
@activity_regularizer.setter
def activity_regularizer(self, regularizer):
"""Optional regularizer function for the output of this layer."""
- self._activity_regularizer = regularizer
+ self._activity_regularizer = self._no_dependency(regularizer)
@property
def trainable_weights(self):
@@ -658,7 +659,8 @@ class Layer(checkpointable.CheckpointableBase):
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = function_utils.fn_args(self.call)
+ self._call_fn_args = self._no_dependency(
+ function_utils.fn_args(self.call))
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index aa84eaa8ab..a4d96de74f 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -81,6 +81,20 @@ class Network(base_layer.Layer):
# Subclassed network
self._init_subclassed_network(**kwargs)
+ # Several Network methods have "no_automatic_dependency_tracking"
+ # annotations. Since Network does automatic dependency tracking on attribute
+ # assignment, including for common data structures such as lists, by default
+ # we'd have quite a few empty dependencies which users don't care about (or
+ # would need some way to ignore dependencies automatically, which is confusing
+ # when applied to user code). Some attributes, such as _layers, would cause
+ # structural issues (_layers being the place where Layers assigned to tracked
+ # attributes are stored).
+ #
+ # Aside from these aesthetic and structural issues, useless dependencies on
+ # empty lists shouldn't cause issues; adding or removing them will not break
+ # checkpoints, but may cause "all Python objects matched" assertions to fail
+ # (in which case less strict assertions may be substituted if necessary).
+ @checkpointable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
@@ -135,6 +149,7 @@ class Network(base_layer.Layer):
# restore operations when graph building.
self._in_progress_restore_finalizer = None
+ @checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
@@ -293,6 +308,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
+ @checkpointable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
@@ -362,10 +378,31 @@ class Network(base_layer.Layer):
self._track_checkpointable(
layer, name='layer-%d' % layer_index, overwrite=True)
+ def _no_dependency(self, value):
+ """Override to allow `Layer` to disable dependency tracking.
+
+ `CheckpointableBase` defines this method, whose semantics are "if a subclass
+ does dependency tracking, this method exempts `value`." Layer uses
+ `_no_dependency` to exempt some of its attribute assignments (conditional on
+ attribute assignment causing tracking in the subclass).
+
+ Args:
+ value: An object which will be assigned to an object attribute, whose
+ value should not be tracked.
+
+ Returns:
+ A wrapped object which, when assigned to an attribute, will not be
+ tracked (`value` will be stored in the attribute).
+ """
+ return data_structures.NoDependency(value)
+
def __setattr__(self, name, value):
- no_dependency = isinstance(value, checkpointable.NoDependency)
- if no_dependency:
- value = value.value
+ if not getattr(self, '_setattr_tracking', True):
+ super(Network, self).__setattr__(name, value)
+ return
+ no_dependency = isinstance(value, data_structures.NoDependency)
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
if isinstance(value, (
base_layer.Layer,
Network,
@@ -377,7 +414,9 @@ class Network(base_layer.Layer):
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
if not is_graph_network:
- if value not in self._layers:
+ # We need to check object identity to avoid de-duplicating empty
+ # container types which compare equal.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_use_resource_variables'):
# In subclassed models, legacy layers (tf.layers) must always use
@@ -385,12 +424,6 @@ class Network(base_layer.Layer):
value._use_resource_variables = True
if (not no_dependency
and isinstance(value, checkpointable.CheckpointableBase)):
- # Layer (and therefore Network/Model) inherit from CheckpointableBase
- # rather than Checkpointable, which means there is no Checkpointable
- # __setattr__ override (it would be a performance issue for functional
- # layers). Therefore Model tracks Checkpointable objects itself.
- self._track_checkpointable(
- checkpointable=value, name=name, overwrite=True)
if ( # For subclassed models only, users may add extra weights/variables
# simply by assigning them to attributes.
not self._is_graph_network
@@ -493,7 +526,8 @@ class Network(base_layer.Layer):
@property
def layers(self):
- return self._layers
+ return checkpointable_layer_utils.filter_empty_layer_containers(
+ self._layers)
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index cd76f08a32..371504a503 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.training import Model
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -108,6 +109,7 @@ class Sequential(Model):
return self._layers[1:]
return self._layers
+ @checkpointable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@@ -191,6 +193,7 @@ class Sequential(Model):
else:
self._layers.append(layer)
+ @checkpointable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.
@@ -210,6 +213,7 @@ class Sequential(Model):
self.outputs = [self.layers[-1].output]
self.build()
+ @checkpointable.no_automatic_dependency_tracking
def build(self, input_shape=None):
if input_shape and not self.inputs:
batch_shape = tuple(input_shape)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index fce6cbdb7a..8e632651fa 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -42,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -115,6 +116,7 @@ class Model(Network):
# Create a cache for dataset - uninitialized iterators
self._dataset_iterator_cache = weakref.WeakKeyDictionary()
+ @checkpointable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@@ -178,6 +180,11 @@ class Model(Network):
raise ValueError('Only TF native optimizers are supported in Eager mode.')
self.optimizer = optimizers.get(optimizer)
+ # We've disabled automatic dependency tracking for this method, but do want
+ # to add a checkpoint dependency on the optimizer if it's checkpointable.
+ if isinstance(self.optimizer, checkpointable.CheckpointableBase):
+ self._track_checkpointable(
+ self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self.metrics = metrics or []
self.loss_weights = loss_weights
@@ -941,6 +948,7 @@ class Model(Network):
str(x[0].shape[0]) + ' samples')
return x, y, sample_weights
+ @checkpointable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, training=None):
"""Set model's input and output specs based on the input data received.
@@ -989,6 +997,7 @@ class Model(Network):
else:
self._symbolic_set_inputs(inputs, training=training)
+ @checkpointable.no_automatic_dependency_tracking
def _eager_set_inputs(self, inputs):
"""Set model's input and output specs based on the input data received.
@@ -1041,6 +1050,7 @@ class Model(Network):
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
+ @checkpointable.no_automatic_dependency_tracking
def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
"""Set model's inputs and output specs based.
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index b7e16a41dd..3ac4852eff 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@@ -679,8 +679,8 @@ class ModelSubclassingTest(test.TestCase):
def __init__(self):
super(Foo, self).__init__()
self.isdep = keras.layers.Dense(1)
- self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
- self.notdep_var = checkpointable.NoDependency(
+ self.notdep = data_structures.NoDependency(keras.layers.Dense(2))
+ self.notdep_var = data_structures.NoDependency(
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
m = Foo()
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index b02cafcf61..0b440185ca 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -31,7 +31,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import tracking as checkpointable
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -688,12 +688,13 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))
-class TFOptimizer(Optimizer, checkpointable.Checkpointable):
+class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
"""Wrapper class for native TensorFlow optimizers.
"""
def __init__(self, optimizer): # pylint: disable=super-init-not-called
self.optimizer = optimizer
+ self._track_checkpointable(optimizer, name='optimizer')
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 54f359489e..35007653a0 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -47,6 +47,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base",
+ ":data_structures",
],
)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 99c8098eca..e9c8c21905 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
+from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
@@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple(
])
+def no_automatic_dependency_tracking(method):
+ """Disables automatic dependency tracking on attribute assignment.
+
+ Use to decorate any method of a Checkpointable object. Attribute assignment in
+ that method will not add dependencies (also respected in Model). Harmless if
+ used in a class which does not do automatic dependency tracking (which means
+ it's safe to use in base classes which may have subclasses which also inherit
+ from Checkpointable).
+
+ Args:
+ method: The method to decorate.
+ Returns:
+ A decorated method which sets and un-sets automatic dependency tracking for
+ the object the method is called on (not thread safe).
+ """
+
+ def _method_wrapper(self, *args, **kwargs):
+ previous_value = getattr(self, "_setattr_tracking", True)
+ self._setattr_tracking = False # pylint: disable=protected-access
+ try:
+ method(self, *args, **kwargs)
+ finally:
+ self._setattr_tracking = previous_value # pylint: disable=protected-access
+
+ return tf_decorator.make_decorator(
+ target=method, decorator_func=_method_wrapper)
+
+
class CheckpointableBase(object):
"""Base class for `Checkpointable` objects without automatic dependencies.
@@ -349,6 +378,11 @@ class CheckpointableBase(object):
checks.
"""
+ # CheckpointableBase does not do automatic dependency tracking, but uses the
+ # no_automatic_dependency_tracking decorator so it can avoid adding
+ # dependencies if a subclass is Checkpointable / inherits from Model (both of
+ # which have __setattr__ overrides).
+ @no_automatic_dependency_tracking
def _maybe_initialize_checkpointable(self):
"""Initialize dependency management.
@@ -386,6 +420,10 @@ class CheckpointableBase(object):
# building.
self._name_based_restores = set()
+ def _no_dependency(self, value):
+ """If automatic dependency tracking is enabled, ignores `value`."""
+ return value
+
def _name_based_attribute_restore(self, checkpoint):
"""Restore the object's attributes from a name-based checkpoint."""
self._name_based_restores.add(checkpoint)
@@ -733,28 +771,3 @@ class CheckpointableBase(object):
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index c46585b417..019d43f09c 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,49 +22,126 @@ import collections
import six
from tensorflow.python.ops import variables
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import layer_utils
-# TODO(allenl): We could track regular Python data structures which get assigned
-# to Checkpointable objects. Making this work with restore-on-create would be
-# tricky; we'd need to re-create nested structures with our own wrapped objects
-# on assignment to an attribute, and track the user's original structure to make
-# sure they don't modify it except through the wrappers (since we could save the
-# user's updated structure, but would have no way to support restore-on-create
-# for those modifications).
-# TODO(allenl): A dictionary data structure would be good too.
-class CheckpointableDataStructure(checkpointable_lib.CheckpointableBase):
+class NoDependency(object):
+ """Allows attribute assignment to `Checkpointable` objects with no dependency.
+
+ Example usage:
+ ```python
+ obj = Checkpointable()
+ obj.has_dependency = tf.Variable(0., name="dep")
+ obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
+ assert obj.no_dependency.name == "nodep:0"
+ ```
+
+ `obj` in this example has a dependency on the variable "dep", and both
+ attributes contain un-wrapped `Variable` objects.
+
+ `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
+ dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
+ `Layer` to the attribute without a checkpoint dependency, but the `Model` will
+ still track the `Layer` (so it will appear in `Model.layers`, and its
+ variables will appear in `Model.variables`).
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _wrap_or_unwrap(value):
+ """Wraps basic data structures, unwraps NoDependency objects."""
+ if isinstance(value, NoDependency):
+ return value.value
+ if isinstance(value, base.CheckpointableBase):
+ return value # Skip conversion for already checkpointable objects.
+ elif isinstance(value, list):
+ return _ListWrapper(value)
+ else:
+ return value
+ # TODO(allenl): Handle other common data structures. Tuples will require
+ # special casing (tuple subclasses are not weak referenceable, so replacement
+ # with a wrapper that subclasses tuple on attribute assignment works poorly,
+ # and replacement with a wrapper that isn't a tuple is also problematic),
+ # probably a tree traversal where the leaves are non-tuples(/namedtuples) to
+ # come up with names. Dictionaries should look like lists.
+
+
+def sticky_attribute_assignment(checkpointable, name, value):
+ """Adds dependencies, generally called from __setattr__.
+
+ This behavior is shared between Checkpointable and Model.
+
+ Respects NoDependency indicators, but otherwise makes checkpointable objects
+ out of common data structures and tracks objects by their attribute names.
+
+ Args:
+ checkpointable: The object to add dependencies to (generally the one having
+ an attribute assigned).
+ name: The attribute name being assigned.
+ value: The value being assigned. Not necessarily a checkpointable object.
+
+ Returns:
+ The value which should be stored in the attribute (unwrapped from a
+ NoDependency object if necessary).
+ """
+ if isinstance(value, NoDependency):
+ add_dependency = False
+ else:
+ add_dependency = True
+ value = _wrap_or_unwrap(value)
+ if not add_dependency:
+ return value
+ if isinstance(value, base.CheckpointableBase):
+ checkpointable._track_checkpointable( # pylint: disable=protected-access
+ value, name=name,
+ # Allow the user to switch the Checkpointable which is tracked by this
+ # name, since assigning a new variable to an attribute has
+ # historically been fine (e.g. Adam did this).
+ overwrite=True)
+ return value
+
+
+class CheckpointableDataStructure(base.CheckpointableBase):
"""Base class for data structures which contain checkpointable objects."""
def __init__(self):
+ # An append-only ordered set
self._layers = []
+
self.trainable = True
self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
- if isinstance(value, checkpointable_lib.CheckpointableBase):
- self._track_checkpointable(value, name=name)
- if isinstance(value, variables.Variable):
- self._extra_variables.append(value)
- else:
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
+ if not isinstance(value, base.CheckpointableBase):
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
if (isinstance(value, CheckpointableDataStructure)
or layer_utils.is_layer(value)):
- if value not in self._layers:
+ # Check for object-identity rather than with __eq__ to avoid
+ # de-duplicating empty container types. Automatically generated list
+ # wrappers keep things like "[] == []" true, which means "[] in [[]]" is
+ # also true. This becomes not true once one of the lists is mutated.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, "_use_resource_variables"):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True # pylint: disable=protected-access
+ return value
@property
def layers(self):
- return self._layers
+ return layer_utils.filter_empty_layer_containers(self._layers)
@property
def trainable_weights(self):
@@ -164,24 +241,28 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __init__(self, *args, **kwargs):
"""Construct a new sequence. Arguments are passed to `list()`."""
super(List, self).__init__()
- self._storage = list(*args, **kwargs)
+ self._storage = self._make_storage(*args, **kwargs)
for index, element in enumerate(self._storage):
- self._track_value(element, name=self._name_element(index))
+ self._storage[index] = self._track_value(
+ element, name=self._name_element(index))
+
+ def _make_storage(self, *args, **kwargs):
+ """Determines the backing storage (overridden in subclasses)."""
+ return list(*args, **kwargs)
def _name_element(self, index):
return "%d" % (index,)
def append(self, value):
"""Add a new checkpointable value."""
- self._track_value(value, self._name_element(len(self._storage)))
+ value = self._track_value(value, self._name_element(len(self._storage)))
self._storage.append(value)
def extend(self, values):
"""Add a sequence of checkpointable values."""
- for index_offset, value in enumerate(values):
- self._track_value(
- value, name=self._name_element(len(self._storage) + index_offset))
- self._storage.extend(values)
+ for value in values:
+ self._storage.append(self._track_value(
+ value, name=self._name_element(len(self._storage))))
def __iadd__(self, values):
self.extend(values)
@@ -189,9 +270,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __add__(self, other):
if isinstance(other, List):
- return List(self._storage + other._storage) # pylint: disable=protected-access
+ return self.__class__(self._storage + other._storage) # pylint: disable=protected-access
else:
- return List(self._storage + other)
+ return self.__class__(self._storage + other)
+
+ def __radd__(self, other):
+ return self + other
def __getitem__(self, key):
return self._storage[key]
@@ -203,6 +287,144 @@ class List(CheckpointableDataStructure, collections.Sequence):
return "List(%s)" % (repr(self._storage),)
+class _ListWrapper(List, collections.MutableSequence,
+ # Shadowed, but there for isinstance checks.
+ list):
+ """Wraps the built-in `list` to support restore-on-create for variables.
+
+ Unlike `List`, this sequence type is mutable in the same ways built-in lists
+ are. Instead of throwing an error immediately like `List`, it records
+ problematic mutations (e.g. assigning a new element to a position already
+ occupied, meaning both elements get the same names at different times) and
+ refuses to save.
+
+ On assignment to an attribute of a Model or Checkpointable object, Python
+ lists are replaced with _ListWrapper. Wrapping a list in a
+ `tf.contrib.checkpoint.NoDependency` object prevents this.
+ """
+
+ def __init__(self, wrapped_list):
+ """Construct a new list wrapper.
+
+ Args:
+ wrapped_list: The initial value of the data structure. A shallow copy may
+ be maintained for error checking. `wrapped_list` itself should not be
+ modified directly after constructing the `_ListWrapper`, and if changes
+ are detected the `_ListWrapper` will throw an exception on save.
+ """
+ # Monotonic flags which indicate this object would not be restored properly,
+ # and therefore should throw an error on save to avoid giving the impression
+ # that restoring it will work.
+ self._non_append_mutation = False
+ self._external_modification = False
+ super(_ListWrapper, self).__init__(wrapped_list)
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ def _make_storage(self, wrapped_list):
+ """Use the user's original list for storage."""
+ return wrapped_list
+
+ def _check_external_modification(self):
+ """Checks for any changes to the wrapped list not through the wrapper."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ if self._storage != self._last_wrapped_list_snapshot:
+ self._external_modification = True
+ self._last_wrapped_list_snapshot = None
+
+ def _update_snapshot(self):
+ """Acknowledges tracked changes to the wrapped list."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ @property
+ def _checkpoint_dependencies(self):
+ self._check_external_modification()
+ if self._non_append_mutation:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). A list element was replaced "
+ "(__setitem__), deleted, or inserted. In order to support "
+ "restoration on object creation, tracking is exclusively for "
+ "append-only data structures.\n\nIf you don't need this list "
+ "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+ "object; it will be automatically un-wrapped and subsequently "
+ "ignored." % (self,)))
+ if self._external_modification:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). The wrapped list was modified "
+ "outside the wrapper (its final value was %s, its value when a "
+ "checkpoint dependency was added was %s), which breaks restoration "
+ "on object creation.\n\nIf you don't need this list checkpointed, "
+ "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
+ "automatically un-wrapped and subsequently ignored." % (
+ self, self._storage, self._last_wrapped_list_snapshot)))
+ return super(_ListWrapper, self)._checkpoint_dependencies
+
+ def __delitem__(self, key):
+ self._non_append_mutation = True
+ del self._storage[key]
+
+ def __setitem__(self, key, value):
+ self._non_append_mutation = True
+ self._storage[key] = value
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._check_external_modification()
+ super(_ListWrapper, self).append(value)
+ self._update_snapshot()
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ self._check_external_modification()
+ super(_ListWrapper, self).extend(values)
+ self._update_snapshot()
+
+ def __eq__(self, other):
+ return self._storage == getattr(other, "_storage", other)
+
+ def __ne__(self, other):
+ return self._storage != getattr(other, "_storage", other)
+
+ def __lt__(self, other):
+ return self._storage < getattr(other, "_storage", other)
+
+ def __le__(self, other):
+ return self._storage <= getattr(other, "_storage", other)
+
+ def __gt__(self, other):
+ return self._storage > getattr(other, "_storage", other)
+
+ def __ge__(self, other):
+ return self._storage >= getattr(other, "_storage", other)
+
+ def __hash__(self):
+ # List wrappers need to compare like regular lists, and so like regular
+ # lists they don't belong in hash tables.
+ raise TypeError("unhashable type: 'ListWrapper'")
+
+ def insert(self, index, obj):
+ self._non_append_mutation = True
+ self._storage.insert(index, obj)
+
+ def _track_value(self, value, name):
+ """Allows storage of non-checkpointable objects."""
+ try:
+ value = super(_ListWrapper, self)._track_value(value=value, name=name)
+ except ValueError:
+ # Even if this value isn't checkpointable, we need to make sure
+ # NoDependency objects get unwrapped.
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ return value
+
+ def __repr__(self):
+ return "ListWrapper(%s)" % (repr(self._storage),)
+
+
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
@@ -217,8 +439,10 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
"""Construct a new sequence. Arguments are passed to `dict()`."""
super(Mapping, self).__init__()
self._storage = dict(*args, **kwargs)
- for key, value in self._storage.items():
- self._track_value(value, name=self._name_element(key))
+ self._storage.update(
+ {key: self._track_value(
+ value, name=self._name_element(key))
+ for key, value in self._storage.items()})
def _name_element(self, key):
if not isinstance(key, six.string_types):
@@ -228,13 +452,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
return str(key)
def __setitem__(self, key, value):
+ name = self._name_element(key)
+ value = self._track_value(value, name=name)
current_value = self._storage.setdefault(key, value)
if current_value is not value:
raise ValueError(
("Mappings are an append-only data structure. Tried to overwrite the "
"key '%s' with value %s, but it already contains %s")
% (key, value, current_value))
- self._track_value(value, name=self._name_element(key))
def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index ce5852dd6e..ec8c9da809 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
class HasList(training.Model):
@@ -113,6 +114,19 @@ class ListTests(test.TestCase):
model(model_input)
self.assertEqual(2, len(model.losses))
+ def testModelContainersCompareEqual(self):
+ class HasEqualContainers(training.Model):
+
+ def __init__(self):
+ super(HasEqualContainers, self).__init__()
+ self.l1 = []
+ self.l2 = []
+
+ model = HasEqualContainers()
+ model.l1.append(HasEqualContainers())
+ model.l2.append(HasEqualContainers())
+ self.assertEqual([model.l1, model.l2], model.layers)
+
def testNotCheckpointable(self):
class NotCheckpointable(object):
pass
@@ -158,11 +172,62 @@ class ListTests(test.TestCase):
self.assertEqual([v], l.trainable_weights)
self.assertEqual([v2], l.non_trainable_weights)
+ def testListWrapperBasic(self):
+ # _ListWrapper, unlike List, compares like the built-in list type (since it
+ # is used to automatically replace lists).
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ self.assertEqual([a, a],
+ [a, a])
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual([a, a],
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ [a, a])
+ self.assertNotEqual([a, a],
+ [b, a])
+ self.assertNotEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([b, a]))
+ self.assertNotEqual([a, a],
+ data_structures._ListWrapper([b, a]))
+ self.assertLess([a], [a, b])
+ self.assertLess(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertLessEqual([a], [a, b])
+ self.assertLessEqual(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertGreater([a, b], [a])
+ self.assertGreater(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertGreaterEqual([a, b], [a])
+ self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertEqual([a], data_structures._ListWrapper([a]))
+ self.assertEqual([a], list(data_structures.List([a])))
+ self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
+ self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
+ self.assertIsInstance(data_structures._ListWrapper([a]), list)
+
+ def testWrapperChangesList(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l_wrapper.append(1)
+ self.assertEqual([1], l)
+
+ def testListChangesWrapper(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l.append(1)
+ self.assertEqual([1], l_wrapper)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
self.assertEqual(2, len(has_sequences))
self.assertNotIn(data_structures.List(), has_sequences)
+ with self.assertRaises(TypeError):
+ has_sequences.add(data_structures._ListWrapper([]))
class HasMapping(training.Model):
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index fdcf963d32..978fcb2252 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -30,6 +30,14 @@ def is_layer(obj):
and hasattr(obj, "variables"))
+def filter_empty_layer_containers(layer_list):
+ """Filter out empty Layer-like containers."""
+ return [layer for layer in layer_list
+ # Filter out only empty Checkpointable data structures. Empty Networks
+ # will still show up in Model.layers.
+ if is_layer(layer) or getattr(layer, "layers", True)]
+
+
def gather_trainable_weights(trainable, sub_layers, extra_variables):
"""Lists the trainable weights for an object with sub-layers.
diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py
index 00e14ac982..bd0bed9d46 100644
--- a/tensorflow/python/training/checkpointable/tracking.py
+++ b/tensorflow/python/training/checkpointable/tracking.py
@@ -18,31 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.training.checkpointable import base
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value
+from tensorflow.python.training.checkpointable import data_structures
class NotCheckpointable(object):
@@ -86,18 +62,11 @@ class Checkpointable(base.CheckpointableBase):
def __setattr__(self, name, value):
"""Support self.foo = checkpointable syntax."""
- # Perform the attribute assignment, and potentially call other __setattr__
- # overrides such as that for tf.keras.Model.
- no_dependency = isinstance(value, NoDependency)
- if no_dependency:
- value = value.value
+ if getattr(self, "_setattr_tracking", True):
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
super(Checkpointable, self).__setattr__(name, value)
- if not no_dependency and isinstance(value, base.CheckpointableBase):
- self._track_checkpointable(
- value, name=name,
- # Allow the user to switch the Checkpointable which is tracked by this
- # name, since assigning a new variable to an attribute has
- # historically been fine (e.g. Adam did this).
- # TODO(allenl): Should this be a warning once Checkpointable save/load
- # is usable?
- overwrite=True)
+
+ def _no_dependency(self, value):
+ """Override to allow CheckpointableBase to disable dependency tracking."""
+ return data_structures.NoDependency(value)
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index baf6f57efb..f0178b074d 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -16,8 +16,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+import numpy
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+from tensorflow.python.util import nest
class InterfaceTests(test.TestCase):
@@ -27,7 +38,7 @@ class InterfaceTests(test.TestCase):
root.leaf = tracking.Checkpointable()
root.leaf = root.leaf
duplicate_name_dep = tracking.Checkpointable()
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(ValueError, "already declared"):
root._track_checkpointable(duplicate_name_dep, name="leaf")
# No error; we're overriding __setattr__, so we can't really stop people
# from doing this while maintaining backward compatibility.
@@ -39,11 +50,119 @@ class InterfaceTests(test.TestCase):
hasdep = tracking.Checkpointable()
root.hasdep = hasdep
nodep = tracking.Checkpointable()
- root.nodep = tracking.NoDependency(nodep)
+ root.nodep = data_structures.NoDependency(nodep)
self.assertEqual(1, len(root._checkpoint_dependencies))
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
self.assertIs(root.hasdep, hasdep)
self.assertIs(root.nodep, nodep)
+ class NoDependencyModel(training.Model):
+
+ @base.no_automatic_dependency_tracking
+ def __init__(self):
+ super(NoDependencyModel, self).__init__()
+ self.a = []
+ self.b = tracking.Checkpointable()
+
+ nodeps = NoDependencyModel()
+ self.assertEqual([nodeps], util.list_objects(nodeps))
+
+ def testListBasic(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ direct_a_dep, = a._checkpoint_dependencies
+ self.assertEqual("l", direct_a_dep.name)
+ self.assertIn(b, direct_a_dep.ref)
+ self.assertIn(c, direct_a_dep.ref)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testMutationDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.insert(0, c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testOutOfBandEditDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ held_reference = [b]
+ a.l = held_reference
+ c = tracking.Checkpointable()
+ held_reference.append(c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNestedLists(self):
+ a = tracking.Checkpointable()
+ a.l = []
+ b = tracking.Checkpointable()
+ a.l.append([b])
+ c = tracking.Checkpointable()
+ a.l[0].append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ a.l[0].append(1)
+ d = tracking.Checkpointable()
+ a.l[0].append(d)
+ a_deps = util.list_objects(a)
+ self.assertIn(d, a_deps)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ self.assertNotIn(1, a_deps)
+ e = tracking.Checkpointable()
+ f = tracking.Checkpointable()
+ a.l1 = [[], [e]]
+ a.l1[0].append(f)
+ a_deps = util.list_objects(a)
+ self.assertIn(e, a_deps)
+ self.assertIn(f, a_deps)
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l[0].append(data_structures.NoDependency([]))
+ a.l[0][-1].append(5)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ # Dirtying the inner list means the root object is unsaveable.
+ a.l[0][1] = 2
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoDepList(self):
+ a = training.Model()
+ a.l1 = data_structures.NoDependency([])
+ a.l1.insert(1, 0)
+ self.assertTrue(isinstance(a.l1, list))
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l2 = []
+ a.l2.insert(1, 0)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAssertions(self):
+ a = tracking.Checkpointable()
+ a.l = [numpy.zeros([2, 2])]
+ self.assertAllEqual([numpy.zeros([2, 2])], a.l)
+ self.assertAllClose([numpy.zeros([2, 2])], a.l)
+ nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])])
+ a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]
+ self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])],
+ self.evaluate(a.tensors))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index e0f61137b1..6ae5765b13 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -40,6 +40,7 @@ from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_contextlib
@@ -93,7 +94,7 @@ class _CheckpointRestoreCoordinator(object):
# use them (for example because of inconsistent references when
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
- self.all_python_objects = weakref.WeakSet()
+ self.all_python_objects = _ObjectIdentityWeakSet()
self.save_path = save_path
self.dtype_map = dtype_map
# When graph building, contains a list of ops to run to restore objects from
@@ -272,11 +273,129 @@ def object_metadata(save_path):
return object_graph_proto
+class _ObjectIdentityWrapper(object):
+ """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
+
+ Since __eq__ is based on object identity, it's safe to also define __hash__
+ based on object ids. This lets us add unhashable types like checkpointable
+ _ListWrapper objects to object-identity collections.
+ """
+
+ def __init__(self, wrapped):
+ self._wrapped = wrapped
+
+ @property
+ def unwrapped(self):
+ return self._wrapped
+
+ def __eq__(self, other):
+ if isinstance(other, _ObjectIdentityWrapper):
+ return self._wrapped is other._wrapped # pylint: disable=protected-access
+ return self._wrapped is other
+
+ def __hash__(self):
+ # Wrapper id() is also fine for weakrefs. In fact, we rely on
+ # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
+ # weakref.ref(a) in _WeakObjectIdentityWrapper.
+ return id(self._wrapped)
+
+
+class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
+
+ def __init__(self, wrapped):
+ super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
+
+ @property
+ def unwrapped(self):
+ return self._wrapped()
+
+
+class _ObjectIdentityDictionary(collections.MutableMapping):
+ """A mutable mapping data structure which compares using "is".
+
+ This is necessary because we have checkpointable objects (_ListWrapper) which
+ have behavior identical to built-in Python lists (including being unhashable
+ and comparing based on the equality of their contents by default).
+ """
+
+ def __init__(self):
+ self._storage = {}
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
+
+ def __getitem__(self, key):
+ return self._storage[self._wrap_key(key)]
+
+ def __setitem__(self, key, value):
+ self._storage[self._wrap_key(key)] = value
+
+ def __delitem__(self, key):
+ del self._storage[self._wrap_key(key)]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ for key in self._storage:
+ yield key.unwrapped
+
+
+class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary):
+ """Like weakref.WeakKeyDictionary, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self._storage))
+
+ def __iter__(self):
+ keys = self._storage.keys()
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ del self[key]
+ else:
+ yield unwrapped
+
+
+class _ObjectIdentityWeakSet(collections.MutableSet):
+ """Like weakref.WeakSet, but compares objects with "is"."""
+
+ def __init__(self):
+ self._storage = set()
+
+ def __contains__(self, key):
+ return _WeakObjectIdentityWrapper(key) in self._storage
+
+ def discard(self, key):
+ self._storage.discard(_WeakObjectIdentityWrapper(key))
+
+ def add(self, key):
+ self._storage.add(_WeakObjectIdentityWrapper(key))
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self))
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ self.discard(key)
+ else:
+ yield unwrapped
+
+
def _breadth_first_checkpointable_traversal(root_checkpointable):
"""Find shortest paths to all variables owned by dependencies of root."""
bfs_sorted = []
to_visit = collections.deque([root_checkpointable])
- path_to_root = {root_checkpointable: ()}
+ path_to_root = _ObjectIdentityDictionary()
+ path_to_root[root_checkpointable] = ()
while to_visit:
current_checkpointable = to_visit.popleft()
if isinstance(current_checkpointable, tracking.NotCheckpointable):
@@ -337,7 +456,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path):
def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
"""Gather and name slot variables."""
non_slot_objects = list(checkpointable_objects)
- slot_variables = {}
+ slot_variables = _ObjectIdentityDictionary()
for checkpointable in non_slot_objects:
if isinstance(checkpointable, optimizer_lib.Optimizer):
naming_scheme = _slot_variable_naming_for_optimizer(
@@ -500,11 +619,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache):
"""
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
slot_variables = _serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -535,11 +655,12 @@ def list_objects(root_checkpointable):
# to run.
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
_serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -988,7 +1109,7 @@ class CheckpointableSaver(object):
else:
# Maps Checkpointable objects -> attribute names -> SaveableObjects, to
# avoid re-creating SaveableObjects when graph building.
- self._saveable_object_cache = weakref.WeakKeyDictionary()
+ self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
@property
def _root_checkpointable(self):
@@ -1310,7 +1431,7 @@ class Checkpoint(tracking.Checkpointable):
with ops.device("/cpu:0"):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
- self._save_counter = tracking.NoDependency(
+ self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64))
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 1104768ae8..d63f59a8c8 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True):
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
+ check_types: if `True` (default) types of sequences are checked as well,
+ including the keys of dictionaries. If set to `False`, for example a
+ list and a tuple of objects will look the same if they have the same
size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
+ considered to have the same shallow structure. Two types will also be
+ considered the same if they are both list subtypes (which allows "list"
+ and "_ListWrapper" from checkpointable dependency tracking to compare
+ equal).
Raises:
ValueError: If the two structures do not have the same number of elements or
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index c79d8a8445..366f8a0deb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -394,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
type2->tp_name);
return true;
}
- } else if (type1 != type2) {
+ } else if (type1 != type2
+ /* If both sequences are list types, don't complain. This allows
+ one to be a list subclass (e.g. _ListWrapper used for automatic
+ dependency tracking.) */
+ && !(PyList_Check(o1) && PyList_Check(o2))) {
*is_type_error = true;
*error_msg = tensorflow::strings::StrCat(
"The two namedtuples don't have the same sequence type. "