-rw-r--r--  tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py    14
-rw-r--r--  tensorflow/python/training/checkpointable.py                     23
-rw-r--r--  tensorflow/python/training/checkpointable_utils.py              226
-rw-r--r--  tensorflow/python/training/checkpointable_utils_test.py          14
4 files changed, 205 insertions, 72 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index b1f2e9d860..20316ec0e3 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -722,12 +722,22 @@ class CheckpointCompatibilityTests(test.TestCase):
with self.assertRaises(AssertionError):
self._check_sentinels(root)
object_saver = checkpointable_utils.CheckpointableSaver(root)
+ self._set_sentinels(root)
status = object_saver.restore(save_path)
- with self.assertRaises(AssertionError):
- status.assert_consumed()
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
+ status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)
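
The assertions above capture the new compatibility behavior: restoring a name-based `tf.train.Saver` checkpoint through the object-based saver now writes values into existing variables immediately when executing eagerly, and `assert_consumed()` reports what is still unmatched instead of failing unconditionally. A minimal sketch of the flow this test exercises, assuming `root` is a checkpointable object holding variables and `save_path` points at a name-based checkpoint:

from tensorflow.python.eager import context
from tensorflow.python.training import checkpointable_utils

saver = checkpointable_utils.CheckpointableSaver(root)
status = saver.restore(save_path)  # falls back to name-based matching
if context.executing_eagerly():
  # Existing variables have already been overwritten from the checkpoint;
  # assert_consumed() still raises because the Python-object metadata
  # (OBJECT_CONFIG_JSON) has no name-based equivalent.
  pass
else:
  # Nothing has been read yet when graph building; add and run restore ops.
  status.run_restore_ops()  # uses the default session
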
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index a57bcaea69..e378f0e898 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -377,6 +377,21 @@ class CheckpointableBase(object):
"Internal error: the object had an update UID set before its "
"initialization code was run.")
self._update_uid = -1
+ # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
+ # instances, which should be checked when creating variables or other
+ # saveables. These are passed on recursively to all dependencies, since,
+ # unlike object-based checkpoint restores, we don't know which subgraph is
+ # being restored in advance. This mechanism is only necessary for
+ # restore-on-create when executing eagerly, and so is unused when graph
+ # building.
+ self._name_based_restores = set()
+
+ def _name_based_attribute_restore(self, checkpoint):
+ """Restore the object's attributes from a name-based checkpoint."""
+ self._name_based_restores.add(checkpoint)
+ if self._update_uid < checkpoint.restore_uid:
+ checkpoint.eager_restore(self)
+ self._update_uid = checkpoint.restore_uid
@property
def _checkpoint_dependencies(self):
@@ -607,6 +622,7 @@ class CheckpointableBase(object):
`CheckpointableBase`).
"""
self._maybe_initialize_checkpointable()
+ checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access
deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
for checkpoint_position in sorted(
deferred_dependencies_list,
@@ -614,6 +630,13 @@ class CheckpointableBase(object):
reverse=True):
checkpoint_position.restore(checkpointable)
+ # Pass on any name-based restores queued in this object.
+ for name_based_restore in sorted(
+ self._name_based_restores,
+ key=lambda checkpoint: checkpoint.restore_uid,
+ reverse=True):
+ checkpointable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
+
def _restore_from_checkpoint_position(self, checkpoint_position):
"""Restore this object and its dependencies (may be deferred)."""
# Attempt a breadth-first traversal, since presumably the user has more
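
The two hunks above add a simple bookkeeping scheme: each object remembers which restore (by `restore_uid`) it last applied, and forwards any queued name-based coordinators to newly tracked dependencies, newest first, so only the most recent checkpoint actually writes values. A standalone sketch of that gating, with illustrative class names rather than the TensorFlow ones:

class _Coordinator(object):
  """Stands in for _NameBasedRestoreCoordinator in this sketch."""

  def __init__(self, restore_uid):
    self.restore_uid = restore_uid

  def eager_restore(self, obj):
    print("restoring %s from checkpoint %d" % (obj.name, self.restore_uid))


class _Trackable(object):
  """Stands in for CheckpointableBase in this sketch."""

  def __init__(self, name):
    self.name = name
    self._update_uid = -1            # no restore applied yet
    self._name_based_restores = set()

  def _name_based_attribute_restore(self, checkpoint):
    self._name_based_restores.add(checkpoint)
    if self._update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)          # newer checkpoint wins
      self._update_uid = checkpoint.restore_uid

  def _track_dependency(self, child):
    # Forward queued restores, newest first; older coordinators are then
    # skipped by the UID check inside the child.
    for restore in sorted(self._name_based_restores,
                          key=lambda c: c.restore_uid, reverse=True):
      child._name_based_attribute_restore(restore)


root = _Trackable("root")
root._name_based_attribute_restore(_Coordinator(restore_uid=1))
root._name_based_attribute_restore(_Coordinator(restore_uid=2))
root._track_dependency(_Trackable("child"))  # restored once, from uid 2
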
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index 72be434fb2..b7d9755268 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -30,13 +30,15 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable as checkpointable_lib
from tensorflow.python.training import optimizer as optimizer_lib
-from tensorflow.python.training import saveable_object
+from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -117,6 +119,76 @@ class _CheckpointRestoreCoordinator(object):
slot_name=slot_reference.slot_name))
+class _NameBasedRestoreCoordinator(object):
+ """Keeps the status of a name-based checkpoint restore."""
+
+ def __init__(self, save_path, dtype_map=None):
+ self.save_path = save_path
+ self.dtype_map = dtype_map
+ self.unused_attributes = weakref.WeakKeyDictionary()
+ self.restore_uid = ops.uid()
+
+ def globally_named_object_attributes(self, checkpointable):
+ """Create globally named SaveableObjects from attributes.
+
+ If an object's attribute has no global name specified (default construction
+ for the SaveableObject factory), records the failure in
+ `self.unused_attributes` (which can then be used to make status assertions
+ fail; see `NameBasedSaverStatus`).
+
+ Args:
+ checkpointable: An object to save.
+
+ Yields:
+ SaveableObjects for `checkpointable`'s attributes.
+ """
+ for attribute_name, saveable_factory in (
+ checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
+ if callable(saveable_factory):
+ try:
+ # This saveable object factory does not have a default name= argument,
+ # which means there's no way to save/restore it using a name-based
+ # checkpoint. Ignore the error now and make sure assert_consumed()
+ # fails.
+ saveable = saveable_factory()
+ except TypeError:
+ self.unused_attributes.setdefault(checkpointable, []).append(
+ attribute_name)
+ continue
+ else:
+ saveable = saveable_factory
+ names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict(
+ [saveable],
+ convert_variable_to_tensor=False)
+ for name, op in names_to_saveables.items():
+ for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
+ op=op, name=name):
+ yield saveable_object
+
+ def eager_restore(self, checkpointable):
+ """Runs restore ops for `checkpointable`'s attributes."""
+ # When graph building, we don't add any restore ops to the graph until
+ # run_restore_ops/initialize_or_restore is called on the status object for
+ # name-based checkpoints.
+ assert context.executing_eagerly()
+ for saveable in self.globally_named_object_attributes(
+ checkpointable):
+ restored_tensors = []
+ for spec in saveable.specs:
+ if spec.name in self.dtype_map:
+ with ops.device("cpu:0"):
+ restored, = io_ops.restore_v2(
+ prefix=self.save_path,
+ tensor_names=[spec.name],
+ shape_and_slices=[""],
+ dtypes=[self.dtype_map[spec.name]],
+ name="%s_checkpoint_read" % (spec.name,))
+ restored_tensors.append(array_ops.identity(restored))
+
+ saveable.restore(restored_tensors=restored_tensors,
+ restored_shapes=None)
+
+
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
def _default_getter(name, shape, dtype, initializer=None,
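
One detail of `globally_named_object_attributes` worth calling out: it probes each saveable factory by calling it with no arguments. Factories that require a `name` keyword (attributes with no usable global name, such as the Python-object JSON) raise `TypeError` and are recorded in `unused_attributes`, which later makes `assert_consumed()` fail. A small self-contained sketch of that probing pattern, using hypothetical factories rather than real SaveableObjects:

import functools

def probe_factories(factories):
  """Split factories into usable saveables and unrestorable attribute names."""
  saveables, unused = [], []
  for attribute_name, factory in factories.items():
    if callable(factory):
      try:
        saveables.append(factory())    # works only if name= has a default
      except TypeError:
        unused.append(attribute_name)  # no global name; skip, report later
    else:
      saveables.append(factory)        # already a saveable, not a factory
  return saveables, unused

# A factory with a default name succeeds; one that requires name= is skipped.
factories = {
    "kernel": functools.partial(dict, name="dense/kernel"),
    "object_config_json": lambda name: {"key": name},  # needs name -> TypeError
}
print(probe_factories(factories))
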
@@ -349,7 +421,7 @@ def _serialize_checkpointables(
maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
else:
maybe_saveable = saveable_factory
- if isinstance(maybe_saveable, saveable_object.SaveableObject):
+ if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
saveables = (maybe_saveable,)
else:
# Figure out the name-based Saver's name for this variable. If it's
@@ -687,32 +759,61 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = (
"Restoring a name-based tf.train.Saver checkpoint using the object-based "
"restore API. This mode uses global names to match variables, and so is "
"somewhat fragile. It also adds new restore ops to the graph each time it "
- "is called. Prefer re-encoding training checkpoints in the object-based "
- "format: run save() on the object-based saver (the same one this message "
- "is coming from) and use that checkpoint in the future.")
+ "is called when graph building. Prefer re-encoding training checkpoints in "
+ "the object-based format: run save() on the object-based saver (the same "
+ "one this message is coming from) and use that checkpoint in the future.")
+@deprecation.deprecated(
+ date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
class NameBasedSaverStatus(_LoadStatus):
"""Status for loading a name-based training checkpoint."""
- def __init__(self, object_saver, save_path):
- self._object_saver = object_saver
- self._save_path = save_path
+ def __init__(self, checkpoint, root_checkpointable):
+ self._checkpoint = checkpoint
+ self._root_checkpointable = root_checkpointable
def assert_consumed(self):
- """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
- raise AssertionError(
- "Restoring a name-based checkpoint. No load status is available.")
+ """Raises an exception if any variables/objects are unmatched."""
+ unused_attributes = dict(self._checkpoint.unused_attributes)
+ if unused_attributes:
+ raise AssertionError(
+ "Some objects had attributes which were not restored: %s"
+ % (unused_attributes,))
+ for checkpointable in list_objects(self._root_checkpointable):
+ # pylint: disable=protected-access
+ checkpointable._maybe_initialize_checkpointable()
+ if checkpointable._update_uid < self._checkpoint.restore_uid:
+ raise AssertionError("Object not restored: %s" % (checkpointable,))
+ # pylint: enable=protected-access
+
+ def _gather_saveable_objects(self):
+ """Walk the object graph, using global names for SaveableObjects."""
+ objects = list_objects(self._root_checkpointable)
+ saveable_objects = []
+ for checkpointable in objects:
+ # pylint: disable=protected-access
+ checkpointable._maybe_initialize_checkpointable()
+ if checkpointable._update_uid < self._checkpoint.restore_uid:
+ checkpointable._update_uid = self._checkpoint.restore_uid
+ else:
+ continue
+ # pylint: enable=protected-access
+ saveable_objects.extend(
+ self._checkpoint.globally_named_object_attributes(
+ checkpointable))
+ return saveable_objects
- @deprecation.deprecated(
- date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
def run_restore_ops(self, session=None):
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
- if session is None and not context.executing_eagerly():
+ if context.executing_eagerly():
+ return # Nothing to do, variables are restored on creation.
+ if session is None:
session = ops.get_default_session()
with ops.device("/cpu:0"):
- saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
- sess=session, save_path=self._save_path)
+ saveables = self._gather_saveable_objects()
+ saver_lib.Saver(saveables).restore(
+ sess=session, save_path=self._checkpoint.save_path)
def initialize_or_restore(self, session=None):
"""Alias for `run_restore_ops`."""
@@ -875,27 +976,6 @@ class CheckpointableSaver(object):
global_step=checkpoint_number)
return save_path
- def _global_variable_names(self):
- """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
- named_saveables, graph_proto, _ = _serialize_object_graph(
- self._root_checkpointable,
- # We destructively modify SaveableObjects, so don't do any caching.
- saveables_cache=None)
- named_saveables = {v.name: v for v in named_saveables}
- saver_names = {}
- for object_proto in graph_proto.nodes:
- for attribute_proto in object_proto.attributes:
- if attribute_proto.full_name:
- # Ignore attributes, such as Python object JSON, which don't have a
- # name-based Saver name.
- saveable = named_saveables[attribute_proto.checkpoint_key]
- saveable.name = attribute_proto.full_name
- for spec in saveable.specs:
- spec.name = spec.name.replace(attribute_proto.checkpoint_key,
- attribute_proto.full_name)
- saver_names[attribute_proto.full_name] = saveable
- return saver_names
-
def restore(self, save_path):
"""Restore a training checkpoint.
@@ -956,8 +1036,32 @@ class CheckpointableSaver(object):
"""
if save_path is None:
return InitializationOnlyStatus(self._root_checkpointable, ops.uid())
- in_graph_mode = not context.executing_eagerly()
- if in_graph_mode:
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ graph_building = not context.executing_eagerly()
+ if graph_building:
+ dtype_map = None
+ else:
+ dtype_map = reader.get_variable_to_dtype_map()
+ try:
+ object_graph_string = reader.get_tensor(
+ checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
+ except errors_impl.NotFoundError:
+ # The object graph proto does not exist in this checkpoint. Try the
+ # name-based compatibility mode.
+ restore_coordinator = _NameBasedRestoreCoordinator(
+ save_path=save_path, dtype_map=dtype_map)
+ if not graph_building:
+ for existing_checkpointable in list_objects(self._root_checkpointable):
+ # pylint: disable=protected-access
+ existing_checkpointable._maybe_initialize_checkpointable()
+ existing_checkpointable._name_based_restores.add(restore_coordinator)
+ existing_checkpointable._name_based_attribute_restore(
+ restore_coordinator)
+ # pylint: enable=protected-access
+ return NameBasedSaverStatus(
+ restore_coordinator, root_checkpointable=self._root_checkpointable)
+
+ if graph_building:
if self._file_prefix_placeholder is None:
with ops.device("/cpu:0"):
self._file_prefix_placeholder = constant_op.constant("model")
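
`restore` now decides the checkpoint format up front: it opens the checkpoint with `NewCheckpointReader` and looks for the serialized object graph; a `NotFoundError` means this is a plain name-based `tf.train.Saver` checkpoint, and the `_NameBasedRestoreCoordinator` fallback is used. A sketch of that probe, using the same internal helpers the diff relies on:

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors_impl
from tensorflow.python.training import checkpointable as checkpointable_lib

def is_object_based_checkpoint(save_path):
  """Return True if `save_path` contains a serialized object graph."""
  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
  try:
    reader.get_tensor(checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
    return True   # object-based checkpoint (written by this saver)
  except errors_impl.NotFoundError:
    return False  # name-based tf.train.Saver checkpoint; use the fallback
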
@@ -967,30 +1071,17 @@ class CheckpointableSaver(object):
with ops.device("/cpu:0"):
file_prefix_tensor = constant_op.constant(save_path)
file_prefix_feed_dict = None
- reader = pywrap_tensorflow.NewCheckpointReader(save_path)
- try:
- object_graph_string = reader.get_tensor(
- checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
- except errors_impl.NotFoundError:
- # The object graph proto does not exist in this checkpoint. Try again with
- # name-based saving.
- return NameBasedSaverStatus(self, save_path)
-
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
- if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
+ if graph_building and object_graph_proto == self._last_restore_object_graph:
checkpoint = self._last_restore_checkpoint
else:
- if in_graph_mode:
- dtype_map = None
- else:
- dtype_map = reader.get_variable_to_dtype_map()
checkpoint = _CheckpointRestoreCoordinator(
object_graph_proto=object_graph_proto,
save_path=file_prefix_tensor,
dtype_map=dtype_map)
- if in_graph_mode:
+ if graph_building:
if self._last_restore_object_graph is not None:
raise NotImplementedError(
"Using a single Saver to restore different object graphs is not "
@@ -1164,8 +1255,8 @@ class Checkpoint(checkpointable_lib.Checkpointable):
Returns:
The full path to the checkpoint.
"""
- in_graph_mode = not context.executing_eagerly()
- if in_graph_mode:
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if session is None:
session = ops.get_default_session()
if self._save_counter is None:
@@ -1173,12 +1264,12 @@ class Checkpoint(checkpointable_lib.Checkpointable):
# needs to be initialized before assign_add. This is only an issue if
# restore() has not been called first.
session.run(self.save_counter.initializer)
- if not in_graph_mode or self._save_assign_op is None:
+ if not graph_building or self._save_assign_op is None:
with ops.colocate_with(self.save_counter):
assign_op = self.save_counter.assign_add(1, read_value=False)
- if in_graph_mode:
+ if graph_building:
self._save_assign_op = assign_op
- if in_graph_mode:
+ if graph_building:
session.run(self._save_assign_op)
return self._saver.save(
file_prefix=file_prefix,
@@ -1224,9 +1315,9 @@ class Checkpoint(checkpointable_lib.Checkpointable):
ops will grow as more objects are added to the dependency graph.
Name-based `tf.train.Saver` checkpoints can be loaded using this
- method. There is no deferred loading, and names are used to match
- variables. No restore ops are created/run until `run_restore_ops()` or
- `initialize_or_restore()` are called on the returned status object, even
+ method. Names are used to match variables. No restore ops are created/run
+ until `run_restore_ops()` or `initialize_or_restore()` are called on the
+ returned status object when graph building, but there is restore-on-creation
when executing eagerly. Re-encode name-based checkpoints using
`tf.train.Checkpoint.save` as soon as possible.
@@ -1252,14 +1343,13 @@ class Checkpoint(checkpointable_lib.Checkpointable):
- `initialize_or_restore(session=None)`:
When graph building, runs variable initializers if `save_path` is
`None`, but otherwise runs restore operations. If no `session` is
- explicitly specified, the default session is used. No effect for
- object-based checkpoints when executing eagerly (variables are
- initialized or restored eagerly).
+ explicitly specified, the default session is used. No effect when
+ executing eagerly (variables are initialized or restored eagerly).
- `run_restore_ops(session=None)`:
When graph building, runs restore operations. If no `session` is
- explicitly specified, the default session is used. No effect for
- object-based checkpoints when executing eagerly (restore operations
- are run eagerly). May only be called when `save_path` is not `None`.
+ explicitly specified, the default session is used. No effect when
+ executing eagerly (restore operations are run eagerly). May only be
+ called when `save_path` is not `None`.
"""
status = self._saver.restore(save_path=save_path)
# Create the save counter now so it gets initialized with other variables
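
Taken together with the docstring above, the intended migration path is: load the old name-based checkpoint through `tf.train.Checkpoint`, then immediately re-save it in the object-based format. A hedged sketch, assuming `model` and `optimizer` are existing checkpointable objects and `name_based_path` points at a `tf.train.Saver` checkpoint:

import tensorflow as tf
from tensorflow.python.eager import context

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(name_based_path)  # name-based compatibility mode

if context.executing_eagerly():
  # Variables are restored immediately (or on creation); nothing to run.
  checkpoint.save("/tmp/ckpt/object_based")   # re-encode right away
else:
  with tf.Session() as session:
    status.initialize_or_restore(session)     # adds and runs the restore ops
    checkpoint.save("/tmp/ckpt/object_based", session=session)
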
diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py
index d94cdcfc06..79a61584e8 100644
--- a/tensorflow/python/training/checkpointable_utils_test.py
+++ b/tensorflow/python/training/checkpointable_utils_test.py
@@ -1396,12 +1396,22 @@ class CheckpointCompatibilityTests(test.TestCase):
with self.assertRaises(AssertionError):
self._check_sentinels(root)
object_saver = checkpointable_utils.CheckpointableSaver(root)
+ self._set_sentinels(root)
status = object_saver.restore(save_path)
- with self.assertRaises(AssertionError):
- status.assert_consumed()
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
+ status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)