path: root/tensorflow/python/training/saver.py
author    Asim Shankar <ashankar@google.com>    2018-03-07 12:03:56 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-03-07 12:10:42 -0800
commit  37cef895bfe06913477b87917cbee7284aefa7cd (patch)
tree    4f05a013578c0459a52fc5e6448bb3dfc2d04971 /tensorflow/python/training/saver.py
parent  808b569e85df8d63590740f05bc14d964efc4801 (diff)
eager: Rename in_eager_mode to executing_eagerly and get rid of in_graph_mode.
This is in preparation for introducing one public, stable symbol: tf.executing_eagerly() (i.e., part of moving APIs related to eager execution from "contrib" to a namespace where we provide API stability guarantees).

PiperOrigin-RevId: 188212646
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r--  tensorflow/python/training/saver.py  |  96
1 file changed, 50 insertions(+), 46 deletions(-)
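
The diff below applies the rename mechanically: each context.in_eager_mode() call becomes context.executing_eagerly(), and each context.in_graph_mode() check becomes not context.executing_eagerly(), usually with the if/else branches reordered so the eager path comes first. A minimal sketch of that call-site pattern follows; it assumes a TensorFlow build whose internal module tensorflow.python.eager.context already exposes executing_eagerly(), and the function name branch_on_mode is purely illustrative.

from tensorflow.python.eager import context


def branch_on_mode():
  # Before this commit the two paths were guarded by context.in_eager_mode()
  # and context.in_graph_mode(); now a single predicate is used, negated for
  # the graph-construction path.
  if context.executing_eagerly():
    # Eager path: operations run immediately and tensors hold concrete values.
    return "eager"
  else:
    # Graph path: operations are added to a graph and executed in a Session.
    return "graph"
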
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index df3ccce63e..2ce57c4432 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -582,7 +582,20 @@ class BaseSaverBuilder(object):
BaseSaverBuilder.OpListToDict(
list(var._gather_saveables_for_checkpoint().values())))
else:
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ if not isinstance(var, resource_variable_ops.ResourceVariable):
+ raise ValueError(
+ "Can only save/restore ResourceVariables when eager execution "
+ "is enabled, type: %s." % type(var))
+ set_var = names_to_saveables.setdefault(var._shared_name, var)
+ if set_var is not var:
+ raise ValueError(
+ ("Two different ResourceVariable objects with the same "
+ "shared_name '%s' were passed to the Saver. This likely means "
+ "that they were created in different Graphs or isolation "
+ "contexts, and may not be checkpointed together.") %
+ (var._shared_name,))
+ else:
if convert_variable_to_tensor:
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var._graph_element # pylint: disable=protected-access
@@ -598,18 +611,6 @@ class BaseSaverBuilder(object):
raise ValueError("At least two variables have the same name: %s" %
name)
names_to_saveables[name] = var
- else:
- if not isinstance(var, resource_variable_ops.ResourceVariable):
- raise ValueError("Can only save/restore ResourceVariable eager "
- "mode is enabled, type: %s." % type(var))
- set_var = names_to_saveables.setdefault(var._shared_name, var)
- if set_var is not var:
- raise ValueError(
- ("Two different ResourceVariable objects with the same "
- "shared_name '%s' were passed to the Saver. This likely means "
- "that they were created in different Graphs or isolation "
- "contexts, and may not be checkpointed together.") % (
- var._shared_name,))
# pylint: enable=protected-access
return names_to_saveables
@@ -671,7 +672,7 @@ class BaseSaverBuilder(object):
# pylint: enable=protected-access
else:
# A variable or tensor.
- if context.in_eager_mode():
+ if context.executing_eagerly():
if not isinstance(op, resource_variable_ops.ResourceVariable):
raise ValueError("Can only save/restore ResourceVariable eager "
"mode is enabled, type: %s." % type(op))
@@ -778,8 +779,10 @@ class BaseSaverBuilder(object):
build_save=True,
build_restore=True):
"""build() with option to only perform save and restore."""
- if context.in_graph_mode() and (not build_save or not build_restore):
- raise ValueError("Graph mode needs to build save and restore together.")
+ if not context.executing_eagerly() and (not build_save or
+ not build_restore):
+ raise ValueError("save and restore operations need to be built together "
+ " when eager execution is not enabled.")
saveables = self._ValidateAndSliceInputs(names_to_saveables)
if max_to_keep is None:
@@ -816,22 +819,22 @@ class BaseSaverBuilder(object):
# such usage model makes sense.
#
# assert restore_op.name.endswith("restore_all"), restore_op.name
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ # Store the tensor values to the tensor_names.
+ save_tensor_name = save_tensor.numpy() if build_save else ""
return saver_pb2.SaverDef(
- filename_tensor_name=filename_tensor.name,
- save_tensor_name=save_tensor.name,
- restore_op_name=restore_op.name,
+ filename_tensor_name=filename_tensor.numpy(),
+ save_tensor_name=save_tensor_name,
+ restore_op_name="",
max_to_keep=max_to_keep,
sharded=sharded,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
else:
- # Store the tensor values to the tensor_names.
- save_tensor_name = save_tensor.numpy() if build_save else ""
return saver_pb2.SaverDef(
- filename_tensor_name=filename_tensor.numpy(),
- save_tensor_name=save_tensor_name,
- restore_op_name="",
+ filename_tensor_name=filename_tensor.name,
+ save_tensor_name=save_tensor.name,
+ restore_op_name=restore_op.name,
max_to_keep=max_to_keep,
sharded=sharded,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
@@ -1280,7 +1283,7 @@ class Saver(object):
raise ValueError(
"If `var_list` is provided then build cannot be deferred. "
"Either set defer_build=False or var_list=None.")
- if context.in_eager_mode() and var_list is None:
+ if context.executing_eagerly() and var_list is None:
raise RuntimeError(
"When eager execution is enabled, `var_list` must specify a list or "
"dict of variables to save")
@@ -1301,10 +1304,10 @@ class Saver(object):
self._filename = filename
self._last_checkpoints = []
self._checkpoints_to_be_deleted = []
- if context.in_eager_mode():
+ if context.executing_eagerly():
self._next_checkpoint_time = (
time.time() + self._keep_checkpoint_every_n_hours * 3600)
- if not defer_build and context.in_graph_mode():
+ elif not defer_build:
self.build()
if self.saver_def:
self._check_saver_def()
@@ -1312,7 +1315,7 @@ class Saver(object):
self._save_relative_paths = save_relative_paths
def build(self):
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("Use save/restore instead of build in eager mode.")
self._build(self._filename, build_save=True, build_restore=True)
@@ -1322,12 +1325,12 @@ class Saver(object):
def _build(self, checkpoint_path, build_save, build_restore):
"""Builds saver_def."""
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if self._is_built:
return
self._is_built = True
- if not self.saver_def or context.in_eager_mode():
+ if not self.saver_def or context.executing_eagerly():
if self._builder is None:
self._builder = BulkSaverBuilder(self._write_version)
@@ -1364,8 +1367,9 @@ class Saver(object):
self.saver_def.restore_op_name, self._name)
self._check_saver_def()
- if context.in_graph_mode(): # Set in __init__ when executing eagerly.
+ if not context.executing_eagerly():
# Updates next checkpoint time.
+ # Set in __init__ when executing eagerly.
self._next_checkpoint_time = (
time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
@@ -1373,7 +1377,7 @@ class Saver(object):
if not isinstance(self.saver_def, saver_pb2.SaverDef):
raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
self.saver_def)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
if not self.saver_def.save_tensor_name:
raise ValueError("saver_def must specify the save_tensor_name: %s" %
str(self.saver_def))
@@ -1623,7 +1627,7 @@ class Saver(object):
RuntimeError: If save and restore ops weren't built.
"""
# pylint: enable=line-too-long
- if not self._is_built and context.in_graph_mode():
+ if not self._is_built and not context.executing_eagerly():
raise RuntimeError(
"`build()` should be called before save if defer_build==True")
if latest_filename is None:
@@ -1655,21 +1659,21 @@ class Saver(object):
"'latest_filename' collides with 'save_path': '%s' and '%s'" %
(latest_filename, save_path))
- if (context.in_graph_mode() and
+ if (not context.executing_eagerly() and
not isinstance(sess, session.SessionInterface)):
raise TypeError("'sess' must be a Session; %s" % sess)
save_path_parent = os.path.dirname(save_path)
if not self._is_empty:
try:
- if context.in_graph_mode():
- model_checkpoint_path = sess.run(
- self.saver_def.save_tensor_name,
- {self.saver_def.filename_tensor_name: checkpoint_file})
- else:
+ if context.executing_eagerly():
self._build_eager(
checkpoint_file, build_save=True, build_restore=False)
model_checkpoint_path = self.saver_def.save_tensor_name
+ else:
+ model_checkpoint_path = sess.run(
+ self.saver_def.save_tensor_name,
+ {self.saver_def.filename_tensor_name: checkpoint_file})
model_checkpoint_path = compat.as_str(model_checkpoint_path)
if write_state:
@@ -1691,7 +1695,7 @@ class Saver(object):
if write_meta_graph:
meta_graph_filename = self._MetaGraphFilename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
- if context.in_graph_mode():
+ if not context.executing_eagerly():
with sess.graph.as_default():
self.export_meta_graph(
meta_graph_filename, strip_default_attrs=strip_default_attrs)
@@ -1764,11 +1768,11 @@ class Saver(object):
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
logging.info("Restoring parameters from %s", save_path)
- if context.in_graph_mode():
+ if context.executing_eagerly():
+ self._build_eager(save_path, build_save=False, build_restore=True)
+ else:
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
- else:
- self._build_eager(save_path, build_save=False, build_restore=True)
@staticmethod
def _add_collection_def(meta_graph_def, key, export_scope=None):
@@ -1908,7 +1912,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled.
@end_compatibility
""" # pylint: disable=g-doc-exception
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
"execution is enabled.")
@@ -1991,7 +1995,7 @@ def export_meta_graph(filename=None,
@end_compatibility
"""
# pylint: enable=line-too-long
- if context.in_eager_mode():
+ if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
"execution is enabled.")