author     Allen Lavoie <allenl@google.com>                 2017-11-14 14:28:04 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-11-14 14:32:12 -0800
commit     3fab2f9bbdf5745643d2dd0a390e1dd762c85bc2
tree       19f38a5041d275998664294b07a0e273d0514550
parent     98b52cfd420fc054ad082bf1865d9eabee0b7a3e
Make save/restore non-members of tfe.Network. This should make it easier to
move to core.

tfe.Network.save -> tfe.save_network_checkpoint
tfe.Network.restore -> tfe.restore_network_checkpoint

Some minor changes in the restore-on-load logic to make it work as a non-member
of Network (particularly in _add_deferred_restoration). The other code changes
are trivial, just moving code around.

PiperOrigin-RevId: 175735659
-rw-r--r--  tensorflow/contrib/eager/python/network.py        860
-rw-r--r--  tensorflow/contrib/eager/python/network_test.py    83
-rw-r--r--  tensorflow/contrib/eager/python/tfe.py               7
3 files changed, 485 insertions(+), 465 deletions(-)
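
For orientation, a minimal usage sketch of the rename. Everything below (the
MyNetwork class, the checkpoint path) is illustrative and not part of the
commit; only save_network_checkpoint and restore_network_checkpoint come from
this change.

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    class MyNetwork(tfe.Network):
      """A one-layer Network, used purely for illustration."""

      def __init__(self, name=None):
        super(MyNetwork, self).__init__(name=name)
        self.dense = self.track_layer(tf.layers.Dense(1, use_bias=False))

      def call(self, x):
        return self.dense(x)

    net = MyNetwork()
    net(tf.constant([[1.0]]))  # Build the Network so its variables exist.

    # Before this commit (methods on Network):
    #   checkpoint_prefix = net.save(save_path="/tmp/my_net_ckpt")
    #   net.restore(save_path=checkpoint_prefix)
    # After this commit (free functions exported from the tfe namespace):
    checkpoint_prefix = tfe.save_network_checkpoint(
        net, save_path="/tmp/my_net_ckpt")
    tfe.restore_network_checkpoint(net, save_path=checkpoint_prefix)
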
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 1a5c6e8aec..713ab1ee57 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -37,185 +37,6 @@ from tensorflow.python.training import training_util
# functions in base.py which should be reused.
-_DeferredRestoration = collections.namedtuple(
-
- "_DeferredRestoration",
- [
- # The map_func to use (either user-specified or the default).
- "map_func",
- # Boolean, True if the user specified an explicit map_func, for error
- # messages.
- "map_func_is_user",
- # A mapping from checkpoint names to initial values of not-yet-created
- # variables which should be restored. These values come from parsing a
- # checkpoint.
- "checkpointed_variables_to_restore",
- # A mapping from checkpoint name to variable objects of variables which
- # have already been restored, for error checking.
- "restored_variables",
- # The session to restore with (if in graph mode).
- "session",
- # Names of the Network where the restore was requested, for error
- # messages.
- "network_name",
- "network_scope_name"
- ])
-
-
-def _default_naming_conflict_error_message(
- mapped_name, first_variable, second_variable,
- network_name, network_scope_name):
- return (
- ("The default checkpoint variable name mapping strategy for Network "
- "'%s' resulted in a naming conflict. We attempted to strip off the "
- "variable prefix for the Network ('%s'), but this resulted in two "
- "variables named '%s' (originally '%s' and '%s'). This should only "
- "happen when using variable sharing (i.e. the Network contains Networks "
- "or Layers which were first added to another Network, and therefore "
- "have that Network's variable prefix). One solution is to pass "
- "`map_func=lambda n: n` to Network.save and Network.restore to use "
- "fully qualified variable names in the checkpoint, although this will "
- "require that the variable prefix of the Network being restored into "
- "is also '%s'. You may alternatively write an arbitrary mapping.")
- % (
- network_name, network_scope_name, mapped_name,
- first_variable._shared_name,
- second_variable._shared_name, network_scope_name
- ))
-
-
-def _restore_custom_map_func_error_message(
- mapped_name, first_variable, second_variable,
- network_name, network_scope_name):
- return (
- ("The map_func passed to Network.restore for the Network '%s' "
- "resulted in two variables named '%s' (originally '%s' and '%s'). Since "
- "this is also an error on Network.save, this Network was "
- "probably not saved with this map_func. Note that map_func "
- "always maps from full variable names to checkpoint names; "
- "there is no need to specify an inverse mapping.\n\n"
- "Try stripping less from the variable names, or renaming parts "
- "of the Network. For reference, variables created by sub-Layers "
- "of this Network are prefixed with '%s', but if they are "
- "re-used after being added to another Network they will have "
- "that Network's full variable prefix instead.") % (
- network_name, mapped_name,
- first_variable._shared_name,
- second_variable._shared_name,
- network_scope_name))
-
-
-def _make_custom_getter_for_deferred_restorations():
- """Returns a custom getter which searches `deferred_restorations`.
-
- Returns: A tuple of (_custom_getter, deferred_restorations)
- _custom_getter: The getter which should be added to variable_scopes where
- variables will be created.
- deferred_restorations: A list for _DeferredRestoration objects. Typically
- empty when the getter is set, and expanded as deferred restorations are
- requested. All new deferred restorations should be appended to the end of
- the list, where they will have priority over older deferred restorations.
- """
- deferred_restorations = []
-
- def _custom_getter(getter, name, shape=None, dtype=None,
- initializer=None,
- *args, **kwargs):
- """A custom getter which processes deferred restorations."""
- # Iterate over restorations, newest first (newer restorations will take
- # precedence over older restorations, just like with immediate restorations
- # into existing variables).
- delayed_restoration = None
- found_value = False
- value_to_restore = None
- for delayed_restoration in reversed(
- deferred_restorations):
- checkpoint_name = delayed_restoration.map_func(name)
- if (checkpoint_name
- in delayed_restoration.checkpointed_variables_to_restore):
- found_value = True
- value_to_restore = (
- delayed_restoration.checkpointed_variables_to_restore[
- checkpoint_name])
- if found_value:
- break
- # value_to_restore may be False because this variable is not in any
- # checkpoint we are restoring, or None because we have explicitly set it to
- # None when it was previously fetched. In either case, we don't need to
- # set an initializer.
- if found_value and value_to_restore is not None:
- initializer = value_to_restore
- shape = None
- variable = getter(name, shape=shape, dtype=dtype, initializer=initializer,
- *args, **kwargs)
- if found_value and value_to_restore is not None:
- # Mark as already restored from this checkpoint.
- delayed_restoration.checkpointed_variables_to_restore[
- checkpoint_name] = None
- if context.in_graph_mode():
- delayed_restoration.session.run(variable.initializer)
- if found_value:
- # Error checking should run even if we've already restored a value.
- if delayed_restoration.restored_variables.setdefault(
- checkpoint_name, variable) is not variable:
- # Naming conflict. We've tried to initialize two variables with the
- # same value from the checkpoint.
- if delayed_restoration.map_func_is_user:
- raise ValueError(
- _restore_custom_map_func_error_message(
- mapped_name=checkpoint_name,
- first_variable=delayed_restoration.restored_variables[
- checkpoint_name],
- second_variable=variable,
- network_name=delayed_restoration.network_name,
- network_scope_name=delayed_restoration.network_scope_name))
- else:
- raise ValueError(
- _default_naming_conflict_error_message(
- mapped_name=checkpoint_name,
- first_variable=delayed_restoration.restored_variables[
- checkpoint_name],
- second_variable=variable,
- network_name=delayed_restoration.network_name,
- network_scope_name=delayed_restoration.network_scope_name))
- return variable
- return _custom_getter, deferred_restorations
-
-
-def _make_prefix_stripping_map_fn(scope_name):
- """Closure for stripping the scope name of a Network.
-
- Implemented as a closure rather than a member function to avoid reference
- cycles in deferred restorations (this function should not have a reference to
- the Network which created it).
-
- Args:
- scope_name: The Network.scope_name to strip from variables.
- Returns:
- A scope_name-stripping default `map_fn` for the Network.
- """
-
- def _strip_variable_prefix(original_variable_name):
- """The default map_func for saving or restoring variables.
-
- Strips the variable prefix for the Network on which save/restore was called,
- and leaves other variable names fully qualified in the checkpoint.
-
- Args:
- original_variable_name: The _shared_name of the variable (no :0
- suffix) to map.
- Returns:
- The checkpoint name of the variable.
- """
- scope_name_with_slash = scope_name + "/"
- if original_variable_name.startswith(scope_name_with_slash):
- return original_variable_name[len(scope_name_with_slash):]
- else:
- return original_variable_name
-
- return _strip_variable_prefix
-
-
class Network(base.Layer):
"""Represents the composition of a set of Layers.
@@ -250,8 +71,6 @@ class Network(base.Layer):
# closed before build is called.
self._variable_scope_counts_on_init = (
variable_scope._get_default_variable_store().variable_scopes_count)
- self._custom_getter, self._deferred_restorations = (
- _make_custom_getter_for_deferred_restorations())
def _init_set_name(self, name):
# Anonymous Networks (name=None) defer setting a final name until they are
@@ -543,252 +362,6 @@ class Network(base.Layer):
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
"important to you")
- def save(self, save_path, global_step=None, map_func=None):
- """Save variables from the Network to a checkpoint.
-
- Args:
- save_path: Either a checkpoint prefix or the name of a directory to save
- the checkpoint in (in which case the checkpoint will be named based on
- the Network name).
- global_step: The global step to use when naming the checkpoint. If None
- (default), we will first try to get the default global step. If that
- fails because no default global step exists, then the checkpoint is
- created without a global step suffix.
- map_func: A function mapping fully qualified variable names
- (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
- default (if `map_func=None`), the variable prefix for the network being
- restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
- and all other variable names (shared with other Networks) are left
- unchanged.
- Returns:
- The checkpoint prefix for the saved checkpoint, which may be passed to
- `Network.restore`.
- Raises:
- ValueError: If the Network has not yet been called, or if map_func results
- in a name collision.
- """
- if not self.built:
- raise ValueError(
- "Attempt to save the Network before it was first called. This means "
- "variables have not yet been created, so there is nothing to save.")
- self._set_scope() # scope_name should be available to map_funcs
- if global_step is None:
- global_step = training_util.get_global_step()
- if os.path.isdir(save_path):
- # If we were passed a directory, default to naming based on the Network
- # name.
- save_path = os.path.join(save_path, self.name.replace("/", "_"))
- user_map_func = map_func
- if map_func is None:
- map_func = _make_prefix_stripping_map_fn(self.scope_name)
- variable_map = {}
- for variable in self.variables:
- mapped_name = map_func(variable._shared_name)
- if variable_map.setdefault(mapped_name, variable) is not variable:
- if user_map_func is None:
- # Instead of erroring out, we could just re-try and silently use the
- # full variable names in the checkpoint. This could be odd for deeply
- # nested sub-Networks (since the full prefix from the nesting would
- # get added), so for now we'll let the user deal with this case.
- raise ValueError(_default_naming_conflict_error_message(
- mapped_name=mapped_name,
- first_variable=variable_map[mapped_name],
- second_variable=variable,
- network_name=self.name,
- network_scope_name=self.scope_name))
- else:
- # The user passed their own problematic map_func.
- raise ValueError(
- ("The map_func passed to Network.save for the Network '%s' "
- "resulted in two variables named '%s' ('%s' and '%s'). Try "
- "stripping less from the variable names, or renaming parts of "
- "the Network. For reference, variables created by sub-Layers of "
- "this Network are prefixed with '%s', but if they are re-used "
- "after being added to another Network, they will have that "
- "Network's full variable prefix instead.") % (
- self.name, mapped_name,
- variable_map[mapped_name]._shared_name,
- variable._shared_name,
- self.scope_name))
- if context.in_eager_mode():
- sess = None
- else:
- sess = ops.get_default_session()
- return saver_lib.Saver(variable_map).save(
- sess=sess, save_path=save_path, write_meta_graph=False,
- global_step=global_step)
-
- def _restore_existing_variables(self, save_path, map_func, user_map_func):
- """Use a standard Saver to restore existing variables from a checkpoint.
-
- Args:
- save_path: The checkpoint prefix or directory to read from.
- map_func: The function to use when mapping from variable names to
- checkpoint names.
- user_map_func: The original map_func passed by the user, for error
- checking.
- Returns:
- A dictionary mapping from checkpoint names to variable objects which have
- been restored (for bookkeeping to avoid deferred restorations on these
- variables).
- Raises:
- ValueError: If there is a name collision.
- """
- existing_variables_by_checkpoint_name = {}
- for variable in self.variables:
- checkpoint_name = map_func(variable._shared_name)
- if existing_variables_by_checkpoint_name.setdefault(
- checkpoint_name, variable) is not variable:
- if user_map_func is None:
- raise ValueError(_default_naming_conflict_error_message(
- mapped_name=checkpoint_name,
- first_variable=existing_variables_by_checkpoint_name[
- checkpoint_name],
- second_variable=variable,
- network_name=self.name,
- network_scope_name=self.scope_name))
- else:
- raise ValueError(_restore_custom_map_func_error_message(
- mapped_name=checkpoint_name,
- first_variable=existing_variables_by_checkpoint_name[
- checkpoint_name],
- second_variable=variable,
- network_name=self.name,
- network_scope_name=self.scope_name))
- if existing_variables_by_checkpoint_name:
- if context.in_eager_mode():
- sess = None
- else:
- sess = ops.get_default_session()
- saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore(
- sess=sess, save_path=save_path)
- return existing_variables_by_checkpoint_name
-
- def _set_restore_on_create(self, save_path, map_func, user_map_func,
- existing_variables_by_checkpoint_name):
- """If necessary, request deferred restorations of variables."""
- checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
- checkpointed_variables_to_restore = {}
- for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
- if checkpoint_name in existing_variables_by_checkpoint_name:
- # This variable was already created and restored.
- continue
- # Save the variable for later restoration in a custom getter.
- checkpointed_variables_to_restore[checkpoint_name] = (
- checkpoint_reader.get_tensor(checkpoint_name))
- # Only set a deferred restoration if there are checkpoint variables which
- # have not been assigned to existing variables. Note that this loses out on
- # some opportunity for error checking, but avoids creating
- # _DeferredRestoration objects once a Network has been built (so that
- # restoring in a loop does not take increasing amounts of memory).
- if checkpointed_variables_to_restore:
- if context.in_eager_mode():
- sess = None
- else:
- sess = ops.get_default_session()
- # We need a name for error messages. If we haven't been added to another
- # Network yet, we're top-level.
- self._finalize_name(False)
- self._set_scope()
- # Save a record of this restoration for use in the custom getter.
- deferred_restoration = _DeferredRestoration(
- map_func=map_func,
- map_func_is_user=(user_map_func is not None),
- checkpointed_variables_to_restore=checkpointed_variables_to_restore,
- restored_variables={},
- session=sess,
- network_name=self.name,
- network_scope_name=self.scope_name)
- self._deferred_restorations.append(deferred_restoration)
- # Add the deferred registration to non-Network children, and request that
- # Networks propagate the request to their children.
- self._add_deferred_restoration(deferred_restoration)
-
- def _add_deferred_restoration(self, deferred_restoration):
- """Add a deferred restoration to this Network and all children.
-
- Restorations which are requested later have higher priority, and the highest
- priority matching restoration is applied to a variable when it is created.
-
- Args:
- deferred_restoration: A _DeferredRestoration object.
- """
- # Networks don't create variables at the moment, so this append isn't
- # strictly necessary. We could get by with only adding deferred restorations
- # to non-Network Layers.
- self._set_scope()
- # We use set_custom_getter because it avoids recursively calling up the
- # variable_scope tree. We've done the tree traversal ourselves and have
- # added the request to each Layer which needs it.
- self._scope.set_custom_getter(self._custom_getter)
- self._deferred_restorations.append(deferred_restoration)
- for layer in self.layers:
- if isinstance(layer, Network):
- # For Networks, request that they propagate this deferred restoration
- # to all of their children recursively.
- layer._add_deferred_restoration(deferred_restoration)
- else:
- # For non-Network Layers, make sure they have a deferred restoration
- # queue and a custom getter, then add our request to it.
- if not hasattr(layer, "_custom_getter"):
- assert not hasattr(layer, "_deferred_restorations")
- layer._custom_getter, layer._deferred_restorations = (
- _make_custom_getter_for_deferred_restorations())
- self._set_scope_for_nonnetwork_sublayer(layer)
- layer._scope.set_custom_getter(layer._custom_getter)
- layer._deferred_restorations.append(deferred_restoration)
-
- def restore(self, save_path, map_func=None):
- """Restore the Network from a checkpoint.
-
- If variables have already been created (typically when some or all of the
- `Network` is built), they are assigned values from the checkpoint
- immediately, overwriting any existing values (in graph mode the default
- session is used for the assignments).
-
- If there are checkpoint entries which do not correspond to any existing
- variables in the `Network`, these values are saved for deferred restoration;
- their initial values will be the checkpointed values once they are
- created. Requests for multiple deferred restorations behave the same way as
- immediate restorations, in that later requests will take priority over
- earlier requests relevant to the same variable.
-
- If this `Network` shares `Layer`s with another network, those `Layer`s will
- also have their variables restored from the checkpoint.
-
- Args:
- save_path: The return value of `Network.save`, or a directory to search
- for a checkpoint.
- map_func: A function mapping fully qualified variable names
- (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
- default (if `map_func=None`), the variable prefix for the network being
- restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
- and all other variable names (shared with other Networks) are left
- unchanged. Note that this is the _same_ map_func as `Network.save`, not
- an inverse mapping.
- """
- self._finalize_name(parent_network=False)
- self._set_scope() # scope_name should be available to map_funcs
- if os.path.isdir(save_path):
- # If we don't have a name yet, set no parent.
- save_path = os.path.join(save_path, self.name.replace("/", "_"))
- user_map_func = map_func
- if map_func is None:
- map_func = _make_prefix_stripping_map_fn(self.scope_name)
- # Step one is to restore any existing variables from the checkpoint.
- existing_variables_by_checkpoint_name = self._restore_existing_variables(
- save_path=save_path,
- map_func=map_func,
- user_map_func=user_map_func)
- # Step two is to set a custom getter which restores variables on creation,
- # for those variables which have not been added to sub-Layers yet.
- self._set_restore_on_create(
- save_path=save_path,
- map_func=map_func,
- user_map_func=user_map_func,
- existing_variables_by_checkpoint_name=(
- existing_variables_by_checkpoint_name))
-
# TODO(josh11b): Support other Layer methods needed for graph mode, such as for
# losses and updates
@@ -838,3 +411,436 @@ class Sequential(Network):
else:
inputs = l(inputs)
return inputs
+
+
+_DeferredRestoration = collections.namedtuple(
+
+ "_DeferredRestoration",
+ [
+ # The map_func to use (either user-specified or the default).
+ "map_func",
+ # Boolean, True if the user specified an explicit map_func, for error
+ # messages.
+ "map_func_is_user",
+ # A mapping from checkpoint names to initial values of not-yet-created
+ # variables which should be restored. These values come from parsing a
+ # checkpoint.
+ "checkpointed_variables_to_restore",
+ # A mapping from checkpoint name to variable objects of variables which
+ # have already been restored, for error checking.
+ "restored_variables",
+ # The session to restore with (if in graph mode).
+ "session",
+ # Names of the Network where the restore was requested, for error
+ # messages.
+ "network_name",
+ "network_scope_name"
+ ])
+
+
+def _default_naming_conflict_error_message(
+ mapped_name, first_variable, second_variable,
+ network_name, network_scope_name):
+ return (
+ ("The default checkpoint variable name mapping strategy for Network "
+ "'%s' resulted in a naming conflict. We attempted to strip off the "
+ "variable prefix for the Network ('%s'), but this resulted in two "
+ "variables named '%s' (originally '%s' and '%s'). This should only "
+ "happen when using variable sharing (i.e. the Network contains Networks "
+ "or Layers which were first added to another Network, and therefore "
+ "have that Network's variable prefix). One solution is to pass "
+ "`map_func=lambda n: n` to save and restore to use fully qualified "
+ "variable names in the checkpoint, although this will require that the "
+ "variable prefix of the Network being restored into is also '%s'. You "
+ "may alternatively write an arbitrary mapping.")
+ % (
+ network_name, network_scope_name, mapped_name,
+ first_variable._shared_name,
+ second_variable._shared_name, network_scope_name
+ ))
+
+
+def _restore_custom_map_func_error_message(
+ mapped_name, first_variable, second_variable,
+ network_name, network_scope_name):
+ return (
+ ("The map_func passed to restore_network_checkpoint for the Network '%s' "
+ "resulted in two variables named '%s' (originally '%s' and '%s'). Since "
+ "this is also an error when saving, this Network was "
+ "probably not saved with this map_func. Note that map_func "
+ "always maps from full variable names to checkpoint names; "
+ "there is no need to specify an inverse mapping.\n\n"
+ "Try stripping less from the variable names, or renaming parts "
+ "of the Network. For reference, variables created by sub-Layers "
+ "of this Network are prefixed with '%s', but if they are "
+ "re-used after being added to another Network they will have "
+ "that Network's full variable prefix instead.") % (
+ network_name, mapped_name,
+ first_variable._shared_name,
+ second_variable._shared_name,
+ network_scope_name))
+
+
+def _make_custom_getter_for_deferred_restorations():
+ """Returns a custom getter which searches `deferred_restorations`.
+
+ Returns: A tuple of (_custom_getter, deferred_restorations)
+ _custom_getter: The getter which should be added to variable_scopes where
+ variables will be created.
+ deferred_restorations: A list for _DeferredRestoration objects. Typically
+ empty when the getter is set, and expanded as deferred restorations are
+ requested. All new deferred restorations should be appended to the end of
+ the list, where they will have priority over older deferred restorations.
+ """
+ deferred_restorations = []
+
+ def _custom_getter(getter, name, shape=None, dtype=None,
+ initializer=None,
+ *args, **kwargs):
+ """A custom getter which processes deferred restorations."""
+ # Iterate over restorations, newest first (newer restorations will take
+ # precedence over older restorations, just like with immediate restorations
+ # into existing variables).
+ delayed_restoration = None
+ found_value = False
+ value_to_restore = None
+ for delayed_restoration in reversed(
+ deferred_restorations):
+ checkpoint_name = delayed_restoration.map_func(name)
+ if (checkpoint_name
+ in delayed_restoration.checkpointed_variables_to_restore):
+ found_value = True
+ value_to_restore = (
+ delayed_restoration.checkpointed_variables_to_restore[
+ checkpoint_name])
+ if found_value:
+ break
+ # value_to_restore may be False because this variable is not in any
+ # checkpoint we are restoring, or None because we have explicitly set it to
+ # None when it was previously fetched. In either case, we don't need to
+ # set an initializer.
+ if found_value and value_to_restore is not None:
+ initializer = value_to_restore
+ shape = None
+ variable = getter(name, shape=shape, dtype=dtype, initializer=initializer,
+ *args, **kwargs)
+ if found_value and value_to_restore is not None:
+ # Mark as already restored from this checkpoint.
+ delayed_restoration.checkpointed_variables_to_restore[
+ checkpoint_name] = None
+ if context.in_graph_mode():
+ delayed_restoration.session.run(variable.initializer)
+ if found_value:
+ # Error checking should run even if we've already restored a value.
+ if delayed_restoration.restored_variables.setdefault(
+ checkpoint_name, variable) is not variable:
+ # Naming conflict. We've tried to initialize two variables with the
+ # same value from the checkpoint.
+ if delayed_restoration.map_func_is_user:
+ raise ValueError(
+ _restore_custom_map_func_error_message(
+ mapped_name=checkpoint_name,
+ first_variable=delayed_restoration.restored_variables[
+ checkpoint_name],
+ second_variable=variable,
+ network_name=delayed_restoration.network_name,
+ network_scope_name=delayed_restoration.network_scope_name))
+ else:
+ raise ValueError(
+ _default_naming_conflict_error_message(
+ mapped_name=checkpoint_name,
+ first_variable=delayed_restoration.restored_variables[
+ checkpoint_name],
+ second_variable=variable,
+ network_name=delayed_restoration.network_name,
+ network_scope_name=delayed_restoration.network_scope_name))
+ return variable
+ return _custom_getter, deferred_restorations
+
+
+def _make_prefix_stripping_map_fn(scope_name):
+ """Closure for stripping the scope name of a Network.
+
+ Implemented as a closure rather than a member function to avoid reference
+ cycles in deferred restorations (this function should not have a reference to
+ the Network which created it).
+
+ Args:
+ scope_name: The Network.scope_name to strip from variables.
+ Returns:
+ A scope_name-stripping default `map_fn` for the Network.
+ """
+
+ def _strip_variable_prefix(original_variable_name):
+ """The default map_func for saving or restoring variables.
+
+ Strips the variable prefix for the Network on which save/restore was called,
+ and leaves other variable names fully qualified in the checkpoint.
+
+ Args:
+ original_variable_name: The _shared_name of the variable (no :0
+ suffix) to map.
+ Returns:
+ The checkpoint name of the variable.
+ """
+ scope_name_with_slash = scope_name + "/"
+ if original_variable_name.startswith(scope_name_with_slash):
+ return original_variable_name[len(scope_name_with_slash):]
+ else:
+ return original_variable_name
+
+ return _strip_variable_prefix
+
+
+def save_network_checkpoint(
+ network, save_path, global_step=None, map_func=None):
+ """Save variables from the Network to a checkpoint.
+
+ Args:
+ network: A Network object to save.
+ save_path: Either a checkpoint prefix or the name of a directory to save
+ the checkpoint in (in which case the checkpoint will be named based on
+ the Network name).
+ global_step: The global step to use when naming the checkpoint. If None
+ (default), we will first try to get the default global step. If that
+ fails because no default global step exists, then the checkpoint is
+ created without a global step suffix.
+ map_func: A function mapping fully qualified variable names
+ (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
+ default (if `map_func=None`), the variable prefix for the network being
+ restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
+ and all other variable names (shared with other Networks) are left
+ unchanged.
+ Returns:
+ The checkpoint prefix for the saved checkpoint, which may be passed to
+ `Network.restore`.
+ Raises:
+ ValueError: If the Network has not yet been called, or if map_func results
+ in a name collision.
+ """
+ if not network.built:
+ raise ValueError(
+ "Attempt to save the Network before it was first called. This means "
+ "variables have not yet been created, so there is nothing to save.")
+ network._set_scope() # scope_name should be available to map_funcs
+ if global_step is None:
+ global_step = training_util.get_global_step()
+ if os.path.isdir(save_path):
+ # If we were passed a directory, default to naming based on the Network
+ # name.
+ save_path = os.path.join(save_path, network.name.replace("/", "_"))
+ user_map_func = map_func
+ if map_func is None:
+ map_func = _make_prefix_stripping_map_fn(network.scope_name)
+ variable_map = {}
+ for variable in network.variables:
+ mapped_name = map_func(variable._shared_name)
+ if variable_map.setdefault(mapped_name, variable) is not variable:
+ if user_map_func is None:
+ # Instead of erroring out, we could just re-try and silently use the
+ # full variable names in the checkpoint. This could be odd for deeply
+ # nested sub-Networks (since the full prefix from the nesting would
+ # get added), so for now we'll let the user deal with this case.
+ raise ValueError(_default_naming_conflict_error_message(
+ mapped_name=mapped_name,
+ first_variable=variable_map[mapped_name],
+ second_variable=variable,
+ network_name=network.name,
+ network_scope_name=network.scope_name))
+ else:
+ # The user passed their own problematic map_func.
+ raise ValueError(
+ ("The map_func passed to save_network_checkpoint for the Network "
+ "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try "
+ "stripping less from the variable names, or renaming parts of "
+ "the Network. For reference, variables created by sub-Layers of "
+ "this Network are prefixed with '%s', but if they are re-used "
+ "after being added to another Network, they will have that "
+ "Network's full variable prefix instead.") % (
+ network.name, mapped_name,
+ variable_map[mapped_name]._shared_name,
+ variable._shared_name,
+ network.scope_name))
+ if context.in_eager_mode():
+ sess = None
+ else:
+ sess = ops.get_default_session()
+ return saver_lib.Saver(variable_map).save(
+ sess=sess, save_path=save_path, write_meta_graph=False,
+ global_step=global_step)
+
+
+def _add_deferred_restoration(layer, deferred_restoration):
+ """Add a deferred restoration to this Layer and all children.
+
+ Restorations which are requested later have higher priority, and the highest
+ priority matching restoration is applied to a variable when it is created.
+
+ Args:
+ layer: The Layer (may not be a Network) to operate on.
+ deferred_restoration: A _DeferredRestoration object.
+ """
+ # Networks don't create variables at the moment, so this append isn't strictly
+ # necessary. We could get by with only adding deferred restorations to
+ # non-Network Layers.
+ if isinstance(layer, Network):
+ layer._set_scope()
+ # Make sure this Layer has a deferred restoration queue and a custom getter,
+ # then add our request to it.
+ if not hasattr(layer, "_custom_getter"):
+ assert not hasattr(layer, "_deferred_restorations")
+ layer._custom_getter, layer._deferred_restorations = (
+ _make_custom_getter_for_deferred_restorations())
+ # We use set_custom_getter because it avoids recursively calling up the
+ # variable_scope tree. We've done the tree traversal ourselves and have added
+ # the request to each Layer which needs it.
+ layer._scope.set_custom_getter(layer._custom_getter)
+ layer._deferred_restorations.append(deferred_restoration)
+ if isinstance(layer, Network):
+ for sublayer in layer.layers:
+ if not isinstance(sublayer, Network):
+ layer._set_scope_for_nonnetwork_sublayer(sublayer)
+ _add_deferred_restoration(sublayer, deferred_restoration)
+
+
+def _restore_existing_variables(network, save_path, map_func, user_map_func):
+ """Use a standard Saver to restore existing variables from a checkpoint.
+
+ Args:
+ network: A Network object to restore.
+ save_path: The checkpoint prefix or directory to read from.
+ map_func: The function to use when mapping from variable names to
+ checkpoint names.
+ user_map_func: The original map_func passed by the user, for error
+ checking.
+ Returns:
+ A dictionary mapping from checkpoint names to variable objects which have
+ been restored (for bookkeeping to avoid deferred restorations on these
+ variables).
+ Raises:
+ ValueError: If there is a name collision.
+ """
+ existing_variables_by_checkpoint_name = {}
+ for variable in network.variables:
+ checkpoint_name = map_func(variable._shared_name)
+ if existing_variables_by_checkpoint_name.setdefault(
+ checkpoint_name, variable) is not variable:
+ if user_map_func is None:
+ raise ValueError(_default_naming_conflict_error_message(
+ mapped_name=checkpoint_name,
+ first_variable=existing_variables_by_checkpoint_name[
+ checkpoint_name],
+ second_variable=variable,
+ network_name=network.name,
+ network_scope_name=network.scope_name))
+ else:
+ raise ValueError(_restore_custom_map_func_error_message(
+ mapped_name=checkpoint_name,
+ first_variable=existing_variables_by_checkpoint_name[
+ checkpoint_name],
+ second_variable=variable,
+ network_name=network.name,
+ network_scope_name=network.scope_name))
+ if existing_variables_by_checkpoint_name:
+ if context.in_eager_mode():
+ sess = None
+ else:
+ sess = ops.get_default_session()
+ saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore(
+ sess=sess, save_path=save_path)
+ return existing_variables_by_checkpoint_name
+
+
+def _set_restore_on_create(network, save_path, map_func, user_map_func,
+ existing_variables_by_checkpoint_name):
+ """If necessary, request deferred restorations of variables."""
+ checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
+ checkpointed_variables_to_restore = {}
+ for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
+ if checkpoint_name in existing_variables_by_checkpoint_name:
+ # This variable was already created and restored.
+ continue
+ # Save the variable for later restoration in a custom getter.
+ checkpointed_variables_to_restore[checkpoint_name] = (
+ checkpoint_reader.get_tensor(checkpoint_name))
+ # Only set a deferred restoration if there are checkpoint variables which
+ # have not been assigned to existing variables. Note that this loses out on
+ # some opportunity for error checking, but avoids creating
+ # _DeferredRestoration objects once a Network has been built (so that
+ # restoring in a loop does not take increasing amounts of memory).
+ if checkpointed_variables_to_restore:
+ if context.in_eager_mode():
+ sess = None
+ else:
+ sess = ops.get_default_session()
+ # We need a name for error messages. If we haven't been added to another
+ # Network yet, we're top-level.
+ network._finalize_name(False)
+ network._set_scope()
+ # Save a record of this restoration for use in the custom getter.
+ deferred_restoration = _DeferredRestoration(
+ map_func=map_func,
+ map_func_is_user=(user_map_func is not None),
+ checkpointed_variables_to_restore=checkpointed_variables_to_restore,
+ restored_variables={},
+ session=sess,
+ network_name=network.name,
+ network_scope_name=network.scope_name)
+ # Add the deferred registration to non-Network children, and request that
+ # Networks propagate the request to their children.
+ _add_deferred_restoration(network, deferred_restoration)
+
+
+def restore_network_checkpoint(network, save_path, map_func=None):
+ """Restore the Network from a checkpoint.
+
+ If variables have already been created (typically when some or all of the
+ `Network` is built), they are assigned values from the checkpoint immediately,
+ overwriting any existing values (in graph mode the default session is used for
+ the assignments).
+
+ If there are checkpoint entries which do not correspond to any existing
+ variables in the `Network`, these values are saved for deferred restoration;
+ their initial values will be the checkpointed values once they are
+ created. Requests for multiple deferred restorations behave the same way as
+ immediate restorations, in that later requests will take priority over earlier
+ requests relevant to the same variable.
+
+ If this `Network` shares `Layer`s with another network, those `Layer`s will
+ also have their variables restored from the checkpoint.
+
+ Args:
+ network: A Network object to restore.
+ save_path: The return value of `tfe.save_network_checkpoint`, or a directory
+ to search for a checkpoint.
+ map_func: A function mapping fully qualified variable names
+ (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
+ default (if `map_func=None`), the variable prefix for the network being
+ restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
+ and all other variable names (shared with other Networks) are left
+ unchanged. Note that this is the _same_ map_func as
+ `tfe.save_network_checkpoint`, not an inverse mapping.
+ """
+ network._finalize_name(parent_network=False)
+ network._set_scope() # scope_name should be available to map_funcs
+ if os.path.isdir(save_path):
+ # If we don't have a name yet, set no parent.
+ save_path = os.path.join(save_path, network.name.replace("/", "_"))
+ user_map_func = map_func
+ if map_func is None:
+ map_func = _make_prefix_stripping_map_fn(network.scope_name)
+ # Step one is to restore any existing variables from the checkpoint.
+ existing_variables_by_checkpoint_name = _restore_existing_variables(
+ network=network,
+ save_path=save_path,
+ map_func=map_func,
+ user_map_func=user_map_func)
+ # Step two is to set a custom getter which restores variables on creation,
+ # for those variables which have not been added to sub-Layers yet.
+ _set_restore_on_create(
+ network=network,
+ save_path=save_path,
+ map_func=map_func,
+ user_map_func=user_map_func,
+ existing_variables_by_checkpoint_name=(
+ existing_variables_by_checkpoint_name))
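
A minimal sketch of the map_func parameter documented in the new functions
above, reusing the illustrative net and path from the sketch after the
diffstat. Passing map_func=lambda n: n keeps fully qualified variable names
in the checkpoint, as the error messages in this file suggest.

    # Keep fully qualified variable names rather than stripping the Network's
    # scope prefix. The same map_func is passed to both save and restore; it is
    # not an inverse mapping, and restoring only lines up if the restoring
    # Network's variable prefix matches the one that was saved.
    fq_prefix = tfe.save_network_checkpoint(
        net, save_path="/tmp/my_net_fq_ckpt", map_func=lambda name: name)
    tfe.restore_network_checkpoint(net, fq_prefix, map_func=lambda name: name)
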
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 1127055c05..e66486d165 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -46,8 +46,8 @@ class NetworkTest(test.TestCase):
def _save_modify_load_network_built(self, net, global_step=None):
checkpoint_directory = self.get_temp_dir()
- checkpoint_path = net.save(
- save_path=checkpoint_directory, global_step=global_step)
+ checkpoint_path = network.save_network_checkpoint(
+ network=net, save_path=checkpoint_directory, global_step=global_step)
input_value = constant_op.constant([[42.0]])
original_output = self.evaluate(net(input_value))
for var in net.variables:
@@ -56,13 +56,13 @@ class NetworkTest(test.TestCase):
self.evaluate(net(input_value)),
original_output)
# Either the returned explicit checkpoint path or the directory should work.
- net.restore(save_path=checkpoint_directory)
+ network.restore_network_checkpoint(net, save_path=checkpoint_directory)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
for var in net.variables:
self.evaluate(var.assign(var + 2.))
- net.restore(save_path=checkpoint_path)
+ network.restore_network_checkpoint(net, save_path=checkpoint_path)
self.assertAllEqual(
original_output,
self.evaluate(net(input_value)))
@@ -91,7 +91,7 @@ class NetworkTest(test.TestCase):
net = MyNetwork(name="abcd")
with self.assertRaisesRegexp(
ValueError, "Attempt to save the Network before it was first called"):
- net.save(self.get_temp_dir())
+ network.save_network_checkpoint(net, self.get_temp_dir())
net(constant_op.constant([[2.0]]))
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
self._save_modify_load_network_built(net, global_step=None)
@@ -105,7 +105,7 @@ class NetworkTest(test.TestCase):
self.evaluate(net.variables[0].assign([[3.]]))
default_global_step = training_util.get_or_create_global_step()
self.evaluate(default_global_step.assign(4242))
- save_path = net.save(self.get_temp_dir())
+ save_path = network.save_network_checkpoint(net, self.get_temp_dir())
self.assertIn("abcd-4242", save_path)
# TODO(allenl): This test creates garbage in some Python versions
@@ -116,10 +116,10 @@ class NetworkTest(test.TestCase):
test_input = constant_op.constant([[2.0]])
net1(test_input)
self.evaluate(net1.trainable_variables[0].assign([[17.0]]))
- save_path = net1.save(save_dir)
+ save_path = network.save_network_checkpoint(net1, save_dir)
# With a pre-build restore we should have the same value.
net2 = MyNetwork()
- net2.restore(save_path)
+ network.restore_network_checkpoint(net2, save_path)
self.assertAllEqual(self.evaluate(net1(test_input)),
self.evaluate(net2(test_input)))
self.assertIsNot(net1.variables[0], net2.variables[0])
@@ -176,11 +176,12 @@ class NetworkTest(test.TestCase):
"checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
"checkpoint_creator/second_layer/kernel": "second_layer/kernel",
}
- save_path = checkpoint_creator.save(
+ save_path = network.save_network_checkpoint(
+ checkpoint_creator,
self.get_temp_dir(),
map_func=lambda full_name: name_mapping[full_name])
load_into = User(use_layer=first_owner.first)
- load_into.restore(save_path)
+ network.restore_network_checkpoint(load_into, save_path)
self.assertEqual(0, len(first_owner.variables))
self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
self.evaluate(load_into(one)))
@@ -201,7 +202,8 @@ class NetworkTest(test.TestCase):
else:
return "user_2/" + original_name
with self.assertRaisesRegexp(ValueError, "garbage collected"):
- load_into.restore(save_path, map_func=_restore_map_func)
+ network.restore_network_checkpoint(
+ load_into, save_path, map_func=_restore_map_func)
@test_util.run_in_graph_and_eager_modes()
def testRestoreIntoSubNetwork(self):
@@ -221,17 +223,18 @@ class NetworkTest(test.TestCase):
whole_model_saver(one)
self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
- whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir())
+ whole_model_checkpoint = network.save_network_checkpoint(
+ whole_model_saver, self.get_temp_dir())
save_from = MyNetwork()
save_from(one)
self.evaluate(save_from.variables[0].assign([[5.]]))
- checkpoint = save_from.save(self.get_temp_dir())
+ checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
save_into_parent = Parent()
- save_into_parent.restore(whole_model_checkpoint)
- save_into_parent.first.restore(checkpoint)
- save_into_parent.first.restore(checkpoint) # deferred loading multiple
- # times is fine
+ network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
+ network.restore_network_checkpoint(save_into_parent.first, checkpoint)
+ # deferred loading multiple times is fine
+ network.restore_network_checkpoint(save_into_parent.first, checkpoint)
save_into_parent(one) # deferred loading
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
@@ -240,9 +243,9 @@ class NetworkTest(test.TestCase):
# (deferred restoration should happen the same way non-deferred happens,
# with later restorations overwriting older ones).
save_into_parent = Parent()
- save_into_parent.first.restore(checkpoint) # deferred loading multiple
- # times is fine
- save_into_parent.restore(whole_model_checkpoint)
+ # deferred loading multiple times is fine
+ network.restore_network_checkpoint(save_into_parent.first, checkpoint)
+ network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
save_into_parent(one) # deferred loading
# We've overwritten the sub-Network restore.
self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
@@ -250,12 +253,12 @@ class NetworkTest(test.TestCase):
self.evaluate(save_into_parent.variables[0].assign([[3.]]))
self.evaluate(save_into_parent.variables[1].assign([[4.]]))
- save_into_parent.second.restore(checkpoint)
+ network.restore_network_checkpoint(save_into_parent.second, checkpoint)
self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
with self.assertRaisesRegexp(errors_impl.NotFoundError,
"not found in checkpoint"):
# The checkpoint is incompatible.
- save_into_parent.restore(checkpoint)
+ network.restore_network_checkpoint(save_into_parent, checkpoint)
@test_util.run_in_graph_and_eager_modes()
def testCustomMapCollisionErrors(self):
@@ -277,25 +280,30 @@ class NetworkTest(test.TestCase):
self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
with self.assertRaisesRegexp(
ValueError,
- "The map_func passed to Network.save for the Network 'parent_1' "
- "resulted in two variables named 'foo'"):
- make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo")
- checkpoint = make_checkpoint.first.save(
- self.get_temp_dir(), map_func=lambda n: "foo")
+ "The map_func passed to save_network_checkpoint for the Network "
+ "'parent_1' resulted in two variables named 'foo'"):
+ network.save_network_checkpoint(
+ make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo")
+ checkpoint = network.save_network_checkpoint(
+ network=make_checkpoint.first,
+ save_path=self.get_temp_dir(),
+ map_func=lambda n: "foo")
loader = Parent()
- loader.restore(checkpoint, map_func=lambda n: "foo")
+ network.restore_network_checkpoint(
+ loader, checkpoint, map_func=lambda n: "foo")
with self.assertRaisesRegexp(
ValueError,
- ("The map_func passed to Network.restore for the Network"
+ ("The map_func passed to restore_network_checkpoint for the Network"
" 'parent_2' resulted in two variables named 'foo'")):
loader(one)
loader = Parent()
loader(one)
with self.assertRaisesRegexp(
ValueError,
- ("The map_func passed to Network.restore for the Network"
+ ("The map_func passed to restore_network_checkpoint for the Network"
" 'parent_3' resulted in two variables named 'foo'")):
- loader.restore(checkpoint, map_func=lambda n: "foo")
+ network.restore_network_checkpoint(
+ loader, checkpoint, map_func=lambda n: "foo")
@test_util.run_in_graph_and_eager_modes()
def testDefaultMapCollisionErrors(self):
@@ -323,7 +331,7 @@ class NetworkTest(test.TestCase):
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_1' resulted in a naming conflict.")):
- make_checkpoint.save(self.get_temp_dir())
+ network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())
class Compatible(network.Network):
@@ -337,14 +345,15 @@ class NetworkTest(test.TestCase):
successful_checkpoint = Compatible()
successful_checkpoint(one)
self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
- checkpoint_path = successful_checkpoint.save(self.get_temp_dir())
+ checkpoint_path = network.save_network_checkpoint(
+ successful_checkpoint, self.get_temp_dir())
load_checkpoint = Parent()
load_checkpoint(one)
with self.assertRaisesRegexp(
ValueError,
("The default checkpoint variable name mapping strategy for Network "
"'parent_2' resulted in a naming conflict.")):
- load_checkpoint.restore(checkpoint_path)
+ network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
def testNoReferenceCyclesAfterCall(self):
@@ -494,17 +503,17 @@ class NetworkTest(test.TestCase):
self.assertStartsWith(
expected_start="scope1/scope2/my_network_1/dense_1/",
actual=net.trainable_weights[0].name)
- save_path = net.save(self.get_temp_dir())
+ save_path = network.save_network_checkpoint(net, self.get_temp_dir())
self.assertIn("scope1_scope2_my_network_1", save_path)
restore_net = MyNetwork()
# Delayed restoration
- restore_net.restore(save_path)
+ network.restore_network_checkpoint(restore_net, save_path)
restore_net(constant_op.constant([[1.0]]))
self.assertAllEqual([[42.]],
self.evaluate(restore_net.variables[0]))
self.evaluate(restore_net.variables[0].assign([[-1.]]))
# Immediate restoration
- restore_net.restore(save_path)
+ network.restore_network_checkpoint(restore_net, save_path)
self.assertAllEqual([[42.]],
self.evaluate(restore_net.variables[0]))
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index b6c687c829..577d3efef6 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -46,13 +46,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@seterr
@@Iterator
-@@Network
@@Saver
@@restore_variables_on_create
@@Variable
@@get_optimizer_variables
@@EagerVariableStore
+@@Network
+@@save_network_checkpoint
+@@restore_network_checkpoint
+
@@in_eager_mode
@@in_graph_mode
@@ -74,6 +77,8 @@ from __future__ import print_function
from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.eager.python.datasets import Iterator
from tensorflow.contrib.eager.python.network import Network
+from tensorflow.contrib.eager.python.network import save_network_checkpoint
+from tensorflow.contrib.eager.python.network import restore_network_checkpoint
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
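
With the new tfe exports above in place, deferred restoration (restoring into
a Network before it has been built, as exercised in network_test.py) looks
like this, reusing the illustrative MyNetwork and checkpoint_prefix from the
first sketch.

    fresh = MyNetwork()                            # Not yet built: no variables.
    tfe.restore_network_checkpoint(fresh, checkpoint_prefix)  # Deferred restore.
    fresh(tf.constant([[1.0]]))  # Variables are created with checkpointed values.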