diff options
author | Asim Shankar <ashankar@google.com> | 2018-03-07 12:03:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-07 12:10:42 -0800 |
commit | 37cef895bfe06913477b87917cbee7284aefa7cd (patch) | |
tree | 4f05a013578c0459a52fc5e6448bb3dfc2d04971 /tensorflow/python/training/saver.py | |
parent | 808b569e85df8d63590740f05bc14d964efc4801 (diff) |
eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation to introduce one public, stable symbol: tf.executing_eagerly()
(i.e., part of moving APIs related to eager execution from "contrib" to a namespace
where we provide API stability guarantees)
PiperOrigin-RevId: 188212646
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r-- | tensorflow/python/training/saver.py | 96 |
1 files changed, 50 insertions, 46 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index df3ccce63e..2ce57c4432 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -582,7 +582,20 @@ class BaseSaverBuilder(object): BaseSaverBuilder.OpListToDict( list(var._gather_saveables_for_checkpoint().values()))) else: - if context.in_graph_mode(): + if context.executing_eagerly(): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "Can only save/restore ResourceVariables when eager execution " + "is enabled, type: %s." % type(var)) + set_var = names_to_saveables.setdefault(var._shared_name, var) + if set_var is not var: + raise ValueError( + ("Two different ResourceVariable objects with the same " + "shared_name '%s' were passed to the Saver. This likely means " + "that they were created in different Graphs or isolation " + "contexts, and may not be checkpointed together.") % + (var._shared_name,)) + else: if convert_variable_to_tensor: if isinstance(var, resource_variable_ops.ResourceVariable): var = var._graph_element # pylint: disable=protected-access @@ -598,18 +611,6 @@ class BaseSaverBuilder(object): raise ValueError("At least two variables have the same name: %s" % name) names_to_saveables[name] = var - else: - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError("Can only save/restore ResourceVariable eager " - "mode is enabled, type: %s." % type(var)) - set_var = names_to_saveables.setdefault(var._shared_name, var) - if set_var is not var: - raise ValueError( - ("Two different ResourceVariable objects with the same " - "shared_name '%s' were passed to the Saver. This likely means " - "that they were created in different Graphs or isolation " - "contexts, and may not be checkpointed together.") % ( - var._shared_name,)) # pylint: enable=protected-access return names_to_saveables @@ -671,7 +672,7 @@ class BaseSaverBuilder(object): # pylint: enable=protected-access else: # A variable or tensor. - if context.in_eager_mode(): + if context.executing_eagerly(): if not isinstance(op, resource_variable_ops.ResourceVariable): raise ValueError("Can only save/restore ResourceVariable eager " "mode is enabled, type: %s." % type(op)) @@ -778,8 +779,10 @@ class BaseSaverBuilder(object): build_save=True, build_restore=True): """build() with option to only perform save and restore.""" - if context.in_graph_mode() and (not build_save or not build_restore): - raise ValueError("Graph mode needs to build save and restore together.") + if not context.executing_eagerly() and (not build_save or + not build_restore): + raise ValueError("save and restore operations need to be built together " + " when eager execution is not enabled.") saveables = self._ValidateAndSliceInputs(names_to_saveables) if max_to_keep is None: @@ -816,22 +819,22 @@ class BaseSaverBuilder(object): # such usage model makes sense. # # assert restore_op.name.endswith("restore_all"), restore_op.name - if context.in_graph_mode(): + if context.executing_eagerly(): + # Store the tensor values to the tensor_names. + save_tensor_name = save_tensor.numpy() if build_save else "" return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.name, - save_tensor_name=save_tensor.name, - restore_op_name=restore_op.name, + filename_tensor_name=filename_tensor.numpy(), + save_tensor_name=save_tensor_name, + restore_op_name="", max_to_keep=max_to_keep, sharded=sharded, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, version=self._write_version) else: - # Store the tensor values to the tensor_names. - save_tensor_name = save_tensor.numpy() if build_save else "" return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.numpy(), - save_tensor_name=save_tensor_name, - restore_op_name="", + filename_tensor_name=filename_tensor.name, + save_tensor_name=save_tensor.name, + restore_op_name=restore_op.name, max_to_keep=max_to_keep, sharded=sharded, keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, @@ -1280,7 +1283,7 @@ class Saver(object): raise ValueError( "If `var_list` is provided then build cannot be deferred. " "Either set defer_build=False or var_list=None.") - if context.in_eager_mode() and var_list is None: + if context.executing_eagerly() and var_list is None: raise RuntimeError( "When eager execution is enabled, `var_list` must specify a list or " "dict of variables to save") @@ -1301,10 +1304,10 @@ class Saver(object): self._filename = filename self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] - if context.in_eager_mode(): + if context.executing_eagerly(): self._next_checkpoint_time = ( time.time() + self._keep_checkpoint_every_n_hours * 3600) - if not defer_build and context.in_graph_mode(): + elif not defer_build: self.build() if self.saver_def: self._check_saver_def() @@ -1312,7 +1315,7 @@ class Saver(object): self._save_relative_paths = save_relative_paths def build(self): - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Use save/restore instead of build in eager mode.") self._build(self._filename, build_save=True, build_restore=True) @@ -1322,12 +1325,12 @@ class Saver(object): def _build(self, checkpoint_path, build_save, build_restore): """Builds saver_def.""" - if context.in_graph_mode(): + if not context.executing_eagerly(): if self._is_built: return self._is_built = True - if not self.saver_def or context.in_eager_mode(): + if not self.saver_def or context.executing_eagerly(): if self._builder is None: self._builder = BulkSaverBuilder(self._write_version) @@ -1364,8 +1367,9 @@ class Saver(object): self.saver_def.restore_op_name, self._name) self._check_saver_def() - if context.in_graph_mode(): # Set in __init__ when executing eagerly. + if not context.executing_eagerly(): # Updates next checkpoint time. + # Set in __init__ when executing eagerly. self._next_checkpoint_time = ( time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) @@ -1373,7 +1377,7 @@ class Saver(object): if not isinstance(self.saver_def, saver_pb2.SaverDef): raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % self.saver_def) - if context.in_graph_mode(): + if not context.executing_eagerly(): if not self.saver_def.save_tensor_name: raise ValueError("saver_def must specify the save_tensor_name: %s" % str(self.saver_def)) @@ -1623,7 +1627,7 @@ class Saver(object): RuntimeError: If save and restore ops weren't built. """ # pylint: enable=line-too-long - if not self._is_built and context.in_graph_mode(): + if not self._is_built and not context.executing_eagerly(): raise RuntimeError( "`build()` should be called before save if defer_build==True") if latest_filename is None: @@ -1655,21 +1659,21 @@ class Saver(object): "'latest_filename' collides with 'save_path': '%s' and '%s'" % (latest_filename, save_path)) - if (context.in_graph_mode() and + if (not context.executing_eagerly() and not isinstance(sess, session.SessionInterface)): raise TypeError("'sess' must be a Session; %s" % sess) save_path_parent = os.path.dirname(save_path) if not self._is_empty: try: - if context.in_graph_mode(): - model_checkpoint_path = sess.run( - self.saver_def.save_tensor_name, - {self.saver_def.filename_tensor_name: checkpoint_file}) - else: + if context.executing_eagerly(): self._build_eager( checkpoint_file, build_save=True, build_restore=False) model_checkpoint_path = self.saver_def.save_tensor_name + else: + model_checkpoint_path = sess.run( + self.saver_def.save_tensor_name, + {self.saver_def.filename_tensor_name: checkpoint_file}) model_checkpoint_path = compat.as_str(model_checkpoint_path) if write_state: @@ -1691,7 +1695,7 @@ class Saver(object): if write_meta_graph: meta_graph_filename = self._MetaGraphFilename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) - if context.in_graph_mode(): + if not context.executing_eagerly(): with sess.graph.as_default(): self.export_meta_graph( meta_graph_filename, strip_default_attrs=strip_default_attrs) @@ -1764,11 +1768,11 @@ class Saver(object): if save_path is None: raise ValueError("Can't load save_path when it is None.") logging.info("Restoring parameters from %s", save_path) - if context.in_graph_mode(): + if context.executing_eagerly(): + self._build_eager(save_path, build_save=False, build_restore=True) + else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - else: - self._build_eager(save_path, build_save=False, build_restore=True) @staticmethod def _add_collection_def(meta_graph_def, key, export_scope=None): @@ -1908,7 +1912,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, execution is enabled. @end_compatibility """ # pylint: disable=g-doc-exception - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " "execution is enabled.") @@ -1991,7 +1995,7 @@ def export_meta_graph(filename=None, @end_compatibility """ # pylint: enable=line-too-long - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " "execution is enabled.") |