4 files changed, 205 insertions, 72 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index b1f2e9d860..20316ec0e3 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -722,12 +722,22 @@ class CheckpointCompatibilityTests(test.TestCase):
       with self.assertRaises(AssertionError):
         self._check_sentinels(root)
       object_saver = checkpointable_utils.CheckpointableSaver(root)
+      self._set_sentinels(root)
       status = object_saver.restore(save_path)
-      with self.assertRaises(AssertionError):
-        status.assert_consumed()
+      if context.executing_eagerly():
+        self._check_sentinels(root)
+      if context.executing_eagerly():
+        with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+          status.assert_consumed()
+      else:
+        # When graph building, we haven't read any keys, so we don't know
+        # whether the restore will be complete.
+        with self.assertRaisesRegexp(AssertionError, "not restored"):
+          status.assert_consumed()
       status.run_restore_ops()
       self._check_sentinels(root)
       self._set_sentinels(root)
+      status = object_saver.restore(save_path)
       status.initialize_or_restore()
       self._check_sentinels(root)

diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index a57bcaea69..e378f0e898 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -377,6 +377,21 @@ class CheckpointableBase(object):
           "Internal error: the object had an update UID set before its "
           "initialization code was run.")
     self._update_uid = -1
+    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
+    # instances, which should be checked when creating variables or other
+    # saveables. These are passed on recursively to all dependencies, since
+    # unlike object-based checkpoint restores we don't know which subgraph is
+    # being restored in advance. This mechanism is only necessary for
+    # restore-on-create when executing eagerly, and so is unused when graph
+    # building.
+    self._name_based_restores = set()
+
+  def _name_based_attribute_restore(self, checkpoint):
+    """Restore the object's attributes from a name-based checkpoint."""
+    self._name_based_restores.add(checkpoint)
+    if self._update_uid < checkpoint.restore_uid:
+      checkpoint.eager_restore(self)
+      self._update_uid = checkpoint.restore_uid

   @property
   def _checkpoint_dependencies(self):
@@ -607,6 +622,7 @@ class CheckpointableBase(object):
         `CheckpointableBase`).
     """
     self._maybe_initialize_checkpointable()
+    checkpointable._maybe_initialize_checkpointable()  # pylint: disable=protected-access
     deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
     for checkpoint_position in sorted(
         deferred_dependencies_list,
@@ -614,6 +630,13 @@ class CheckpointableBase(object):
         reverse=True):
       checkpoint_position.restore(checkpointable)

+    # Pass on any name-based restores queued in this object.
+    for name_based_restore in sorted(
+        self._name_based_restores,
+        key=lambda checkpoint: checkpoint.restore_uid,
+        reverse=True):
+      checkpointable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access
+
   def _restore_from_checkpoint_position(self, checkpoint_position):
     """Restore this object and its dependencies (may be deferred)."""
     # Attempt a breadth-first traversal, since presumably the user has more
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index 72be434fb2..b7d9755268 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -30,13 +30,15 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import checkpointable as checkpointable_lib
 from tensorflow.python.training import optimizer as optimizer_lib
-from tensorflow.python.training import saveable_object
+from tensorflow.python.training import saveable_object as saveable_object_lib
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
@@ -117,6 +119,76 @@ class _CheckpointRestoreCoordinator(object):
                     slot_name=slot_reference.slot_name))


+class _NameBasedRestoreCoordinator(object):
+  """Keeps the status of a name-based checkpoint restore."""
+
+  def __init__(self, save_path, dtype_map=None):
+    self.save_path = save_path
+    self.dtype_map = dtype_map
+    self.unused_attributes = weakref.WeakKeyDictionary()
+    self.restore_uid = ops.uid()
+
+  def globally_named_object_attributes(self, checkpointable):
+    """Create globally named SaveableObjects from attributes.
+
+    If an object's attribute has no global name specified (default construction
+    for the SaveableObject factory), records the failure in
+    `self.unused_attributes` (which can then be used to make status assertions
+    fail; see `NameBasedSaverStatus`).
+
+    Args:
+      checkpointable: An object to save.
+
+    Yields:
+      SaveableObjects for `checkpointable`'s attributes.
+    """
+    for attribute_name, saveable_factory in (
+        checkpointable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
+      if callable(saveable_factory):
+        try:
+          # This saveable object factory does not have a default name= argument,
+          # which means there's no way to save/restore it using a name-based
+          # checkpoint. Ignore the error now and make sure assert_consumed()
+          # fails.
+          saveable = saveable_factory()
+        except TypeError:
+          self.unused_attributes.setdefault(checkpointable, []).append(
+              attribute_name)
+          continue
+      else:
+        saveable = saveable_factory
+      names_to_saveables = saver_lib.BaseSaverBuilder.OpListToDict(
+          [saveable],
+          convert_variable_to_tensor=False)
+      for name, op in names_to_saveables.items():
+        for saveable_object in saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
+            op=op, name=name):
+          yield saveable_object
+
+  def eager_restore(self, checkpointable):
+    """Runs restore ops for `checkpointable`'s attributes."""
+    # When graph building, we don't add any restore ops to the graph until
+    # run_restore_ops/initialize_or_restore on the status object for name-based
+    # checkpoints.
+    assert context.executing_eagerly()
+    for saveable in self.globally_named_object_attributes(
+        checkpointable):
+      restored_tensors = []
+      for spec in saveable.specs:
+        if spec.name in self.dtype_map:
+          with ops.device("cpu:0"):
+            restored, = io_ops.restore_v2(
+                prefix=self.save_path,
+                tensor_names=[spec.name],
+                shape_and_slices=[""],
+                dtypes=[self.dtype_map[spec.name]],
+                name="%s_checkpoint_read" % (spec.name,))
+          restored_tensors.append(array_ops.identity(restored))
+
+      saveable.restore(restored_tensors=restored_tensors,
+                       restored_shapes=None)
+
+
 # TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
 # or consolidating the implementation with get_variable.
 def _default_getter(name, shape, dtype, initializer=None,
@@ -349,7 +421,7 @@ def _serialize_checkpointables(
       maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
     else:
       maybe_saveable = saveable_factory
-    if isinstance(maybe_saveable, saveable_object.SaveableObject):
+    if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
       saveables = (maybe_saveable,)
     else:
       # Figure out the name-based Saver's name for this variable. If it's
@@ -687,32 +759,61 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = (
     "Restoring a name-based tf.train.Saver checkpoint using the object-based "
     "restore API. This mode uses global names to match variables, and so is "
     "somewhat fragile. It also adds new restore ops to the graph each time it "
-    "is called. Prefer re-encoding training checkpoints in the object-based "
-    "format: run save() on the object-based saver (the same one this message "
-    "is coming from) and use that checkpoint in the future.")
+    "is called when graph building. Prefer re-encoding training checkpoints in "
+    "the object-based format: run save() on the object-based saver (the same "
+    "one this message is coming from) and use that checkpoint in the future.")


+@deprecation.deprecated(
+    date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
 class NameBasedSaverStatus(_LoadStatus):
   """Status for loading a name-based training checkpoint."""

-  def __init__(self, object_saver, save_path):
-    self._object_saver = object_saver
-    self._save_path = save_path
+  def __init__(self, checkpoint, root_checkpointable):
+    self._checkpoint = checkpoint
+    self._root_checkpointable = root_checkpointable

   def assert_consumed(self):
-    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
-    raise AssertionError(
-        "Restoring a name-based checkpoint. No load status is available.")
+    """Raises an exception if any variables/objects are unmatched."""
+    unused_attributes = dict(self._checkpoint.unused_attributes)
+    if unused_attributes:
+      raise AssertionError(
+          "Some objects had attributes which were not restored: %s"
+          % (unused_attributes,))
+    for checkpointable in list_objects(self._root_checkpointable):
+      # pylint: disable=protected-access
+      checkpointable._maybe_initialize_checkpointable()
+      if checkpointable._update_uid < self._checkpoint.restore_uid:
+        raise AssertionError("Object not restored: %s" % (checkpointable,))
+      # pylint: enable=protected-access
+
+  def _gather_saveable_objects(self):
+    """Walk the object graph, using global names for SaveableObjects."""
+    objects = list_objects(self._root_checkpointable)
+    saveable_objects = []
+    for checkpointable in objects:
+      # pylint: disable=protected-access
+      checkpointable._maybe_initialize_checkpointable()
+      if checkpointable._update_uid < self._checkpoint.restore_uid:
+        checkpointable._update_uid = self._checkpoint.restore_uid
+      else:
+        continue
+      # pylint: enable=protected-access
+      saveable_objects.extend(
+          self._checkpoint.globally_named_object_attributes(
+              checkpointable))
+    return saveable_objects

-  @deprecation.deprecated(
-      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
   def run_restore_ops(self, session=None):
     """Load the name-based training checkpoint using a new `tf.train.Saver`."""
-    if session is None and not context.executing_eagerly():
+    if context.executing_eagerly():
+      return  # Nothing to do, variables are restored on creation.
+    if session is None:
       session = ops.get_default_session()
     with ops.device("/cpu:0"):
-      saver_lib.Saver(self._object_saver._global_variable_names()).restore(  # pylint: disable=protected-access
-          sess=session, save_path=self._save_path)
+      saveables = self._gather_saveable_objects()
+      saver_lib.Saver(saveables).restore(
+          sess=session, save_path=self._checkpoint.save_path)

   def initialize_or_restore(self, session=None):
     """Alias for `run_restore_ops`."""
@@ -875,27 +976,6 @@ class CheckpointableSaver(object):
         global_step=checkpoint_number)
     return save_path

-  def _global_variable_names(self):
-    """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
-    named_saveables, graph_proto, _ = _serialize_object_graph(
-        self._root_checkpointable,
-        # We destructively modify SaveableObjects, so don't do any caching.
-        saveables_cache=None)
-    named_saveables = {v.name: v for v in named_saveables}
-    saver_names = {}
-    for object_proto in graph_proto.nodes:
-      for attribute_proto in object_proto.attributes:
-        if attribute_proto.full_name:
-          # Ignore attributes, such as Python object JSON, which don't have a
-          # name-based Saver name.
-          saveable = named_saveables[attribute_proto.checkpoint_key]
-          saveable.name = attribute_proto.full_name
-          for spec in saveable.specs:
-            spec.name = spec.name.replace(attribute_proto.checkpoint_key,
-                                          attribute_proto.full_name)
-          saver_names[attribute_proto.full_name] = saveable
-    return saver_names
-
   def restore(self, save_path):
     """Restore a training checkpoint.

@@ -956,8 +1036,32 @@ class CheckpointableSaver(object):
     """
     if save_path is None:
       return InitializationOnlyStatus(self._root_checkpointable, ops.uid())
-    in_graph_mode = not context.executing_eagerly()
-    if in_graph_mode:
+    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+    graph_building = not context.executing_eagerly()
+    if graph_building:
+      dtype_map = None
+    else:
+      dtype_map = reader.get_variable_to_dtype_map()
+    try:
+      object_graph_string = reader.get_tensor(
+          checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
+    except errors_impl.NotFoundError:
+      # The object graph proto does not exist in this checkpoint. Try the
+      # name-based compatibility mode.
+      restore_coordinator = _NameBasedRestoreCoordinator(
+          save_path=save_path, dtype_map=dtype_map)
+      if not graph_building:
+        for existing_checkpointable in list_objects(self._root_checkpointable):
+          # pylint: disable=protected-access
+          existing_checkpointable._maybe_initialize_checkpointable()
+          existing_checkpointable._name_based_restores.add(restore_coordinator)
+          existing_checkpointable._name_based_attribute_restore(
+              restore_coordinator)
+          # pylint: enable=protected-access
+      return NameBasedSaverStatus(
+          restore_coordinator, root_checkpointable=self._root_checkpointable)
+
+    if graph_building:
       if self._file_prefix_placeholder is None:
         with ops.device("/cpu:0"):
           self._file_prefix_placeholder = constant_op.constant("model")
@@ -967,30 +1071,17 @@ class CheckpointableSaver(object):
       with ops.device("/cpu:0"):
         file_prefix_tensor = constant_op.constant(save_path)
       file_prefix_feed_dict = None
-    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
-    try:
-      object_graph_string = reader.get_tensor(
-          checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
-    except errors_impl.NotFoundError:
-      # The object graph proto does not exist in this checkpoint. Try again with
-      # name-based saving.
-      return NameBasedSaverStatus(self, save_path)
-
     object_graph_proto = (
         checkpointable_object_graph_pb2.CheckpointableObjectGraph())
     object_graph_proto.ParseFromString(object_graph_string)
-    if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
+    if graph_building and object_graph_proto == self._last_restore_object_graph:
       checkpoint = self._last_restore_checkpoint
     else:
-      if in_graph_mode:
-        dtype_map = None
-      else:
-        dtype_map = reader.get_variable_to_dtype_map()
       checkpoint = _CheckpointRestoreCoordinator(
          object_graph_proto=object_graph_proto,
          save_path=file_prefix_tensor,
          dtype_map=dtype_map)
-    if in_graph_mode:
+    if graph_building:
       if self._last_restore_object_graph is not None:
         raise NotImplementedError(
             "Using a single Saver to restore different object graphs is not "
@@ -1164,8 +1255,8 @@ class Checkpoint(checkpointable_lib.Checkpointable):
     Returns:
       The full path to the checkpoint.
     """
-    in_graph_mode = not context.executing_eagerly()
-    if in_graph_mode:
+    graph_building = not context.executing_eagerly()
+    if graph_building:
       if session is None:
         session = ops.get_default_session()
       if self._save_counter is None:
@@ -1173,12 +1264,12 @@ class Checkpoint(checkpointable_lib.Checkpointable):
         # needs to be initialized before assign_add. This is only an issue if
        # restore() has not been called first.
         session.run(self.save_counter.initializer)
-    if not in_graph_mode or self._save_assign_op is None:
+    if not graph_building or self._save_assign_op is None:
       with ops.colocate_with(self.save_counter):
         assign_op = self.save_counter.assign_add(1, read_value=False)
-      if in_graph_mode:
+      if graph_building:
         self._save_assign_op = assign_op
-    if in_graph_mode:
+    if graph_building:
       session.run(self._save_assign_op)
     return self._saver.save(
         file_prefix=file_prefix,
         global_step=checkpoint_number)
@@ -1224,9 +1315,9 @@ class Checkpoint(checkpointable_lib.Checkpointable):
     ops will grow as more objects are added to the dependency graph.

     Name-based `tf.train.Saver` checkpoints can be loaded using this
-    method. There is no deferred loading, and names are used to match
-    variables. No restore ops are created/run until `run_restore_ops()` or
-    `initialize_or_restore()` are called on the returned status object, even
+    method. Names are used to match variables. No restore ops are created/run
+    until `run_restore_ops()` or `initialize_or_restore()` are called on the
+    returned status object when graph building, but there is restore-on-creation
     when executing eagerly. Re-encode name-based checkpoints using
     `tf.train.Checkpoint.save` as soon as possible.

@@ -1252,14 +1343,13 @@ class Checkpoint(checkpointable_lib.Checkpointable):
     - `initialize_or_restore(session=None)`:
         When graph building, runs variable initializers if `save_path` is
         `None`, but otherwise runs restore operations. If no `session` is
-        explicitly specified, the default session is used. No effect for
-        object-based checkpoints when executing eagerly (variables are
-        initialized or restored eagerly).
+        explicitly specified, the default session is used. No effect when
+        executing eagerly (variables are initialized or restored eagerly).
     - `run_restore_ops(session=None)`:
         When graph building, runs restore operations. If no `session` is
-        explicitly specified, the default session is used. No effect for
-        object-based checkpoints when executing eagerly (restore operations
-        are run eagerly). May only be called when `save_path` is not `None`.
+        explicitly specified, the default session is used. No effect when
+        executing eagerly (restore operations are run eagerly). May only be
+        called when `save_path` is not `None`.
     """
     status = self._saver.restore(save_path=save_path)
     # Create the save counter now so it gets initialized with other variables
diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py
index d94cdcfc06..79a61584e8 100644
--- a/tensorflow/python/training/checkpointable_utils_test.py
+++ b/tensorflow/python/training/checkpointable_utils_test.py
@@ -1396,12 +1396,22 @@ class CheckpointCompatibilityTests(test.TestCase):
       with self.assertRaises(AssertionError):
         self._check_sentinels(root)
       object_saver = checkpointable_utils.CheckpointableSaver(root)
+      self._set_sentinels(root)
       status = object_saver.restore(save_path)
-      with self.assertRaises(AssertionError):
-        status.assert_consumed()
+      if context.executing_eagerly():
+        self._check_sentinels(root)
+      if context.executing_eagerly():
+        with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+          status.assert_consumed()
+      else:
+        # When graph building, we haven't read any keys, so we don't know
+        # whether the restore will be complete.
+        with self.assertRaisesRegexp(AssertionError, "not restored"):
+          status.assert_consumed()
       status.run_restore_ops()
       self._check_sentinels(root)
       self._set_sentinels(root)
+      status = object_saver.restore(save_path)
       status.initialize_or_restore()
       self._check_sentinels(root)