aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/base.py')
-rw-r--r--tensorflow/python/training/checkpointable/base.py71
1 files changed, 39 insertions, 32 deletions
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 99c8098eca..f0703c8af4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
+from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
@@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple(
])
+def no_automatic_dependency_tracking(method):
+ """Disables automatic dependency tracking on attribute assignment.
+
+ Use to decorate any method of a Checkpointable object. Attribute assignment in
+ that method will not add dependencies (also respected in Model). Harmless if
+ used in a class which does not do automatic dependency tracking (which means
+ it's safe to use in base classes which may have subclasses which also inherit
+ from Checkpointable).
+
+ Args:
+ method: The method to decorate.
+ Returns:
+ A decorated method which sets and un-sets automatic dependency tracking for
+ the object the method is called on (not thread safe).
+ """
+
+ def _method_wrapper(self, *args, **kwargs):
+ previous_value = getattr(self, "_setattr_tracking", True)
+ self._setattr_tracking = False # pylint: disable=protected-access
+ try:
+ method(self, *args, **kwargs)
+ finally:
+ self._setattr_tracking = previous_value # pylint: disable=protected-access
+
+ return tf_decorator.make_decorator(
+ target=method, decorator_func=_method_wrapper)
+
+
class CheckpointableBase(object):
"""Base class for `Checkpointable` objects without automatic dependencies.
@@ -349,6 +378,11 @@ class CheckpointableBase(object):
checks.
"""
+ # CheckpointableBase does not do automatic dependency tracking, but uses the
+ # no_automatic_dependency_tracking decorator so it can avoid adding
+ # dependencies if a subclass is Checkpointable / inherits from Model (both of
+ # which have __setattr__ overrides).
+ @no_automatic_dependency_tracking
def _maybe_initialize_checkpointable(self):
"""Initialize dependency management.
@@ -386,6 +420,10 @@ class CheckpointableBase(object):
# building.
self._name_based_restores = set()
+ def _no_dependency(self, value):
+ """If automatic dependency tracking is enabled, ignores `value`."""
+ return value
+
def _name_based_attribute_restore(self, checkpoint):
"""Restore the object's attributes from a name-based checkpoint."""
self._name_based_restores.add(checkpoint)
@@ -463,12 +501,6 @@ class CheckpointableBase(object):
ValueError: If the variable name is not unique.
"""
self._maybe_initialize_checkpointable()
- if not overwrite and self._lookup_dependency(name) is not None:
- raise ValueError(
- ("A variable named '%s' already exists in this Checkpointable, but "
- "Checkpointable._add_variable called to create another with "
- "that name. Variable names must be unique within a Checkpointable "
- "object.") % (name,))
with ops.init_scope():
if context.executing_eagerly():
# If this is a variable with a single Tensor stored in the checkpoint,
@@ -593,9 +625,9 @@ class CheckpointableBase(object):
self._unconditional_checkpoint_dependencies[index] = new_reference
elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
- self._unconditional_dependency_names[name] = checkpointable
self._handle_deferred_dependencies(
name=name, checkpointable=checkpointable)
+ self._unconditional_dependency_names[name] = checkpointable
return checkpointable
def _handle_deferred_dependencies(self, name, checkpointable):
@@ -733,28 +765,3 @@ class CheckpointableBase(object):
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value