about summary refs log tree commit diff homepage
path: root/tensorflow/python/training/saver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r--  tensorflow/python/training/saver.py | 84
1 files changed, 61 insertions, 23 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 53ed89e4ab..c80cdf03be 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import collections
import os.path
import re
-import sys
import time
import uuid
@@ -127,8 +126,10 @@ class BaseSaverBuilder(object):
def f():
with ops.device(v.device):
x = v.read_value()
- with ops.device("/device:CPU:0"):
- return array_ops.identity(x)
+ # To allow variables placed on non-CPU devices to be checkpointed,
+ # we copy them to CPU on the same machine first.
+ with ops.device("/device:CPU:0"):
+ return array_ops.identity(x)
return f
self.handle_op = var.handle
@@ -1043,8 +1044,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
- raise ValueError("Invalid checkpoint state loaded from %s",
- checkpoint_dir)
+ raise ValueError("Invalid checkpoint state loaded from "
+ + checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
@@ -1706,12 +1707,17 @@ class Saver(object):
save_path: Path where parameters were previously saved.
Raises:
- ValueError: If save_path is None.
+ ValueError: If save_path is None or not a valid checkpoint.
"""
if self._is_empty:
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+
+ if not checkpoint_exists(compat.as_text(save_path)):
+ raise ValueError("The passed save_path is not a valid checkpoint: "
+ + compat.as_text(save_path))
+
logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
if context.executing_eagerly():
@@ -1719,23 +1725,24 @@ class Saver(object):
else:
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
- except errors.NotFoundError:
- exception_type, exception_value, exception_traceback = sys.exc_info()
- # The checkpoint would not be loaded successfully as is. Try to parse it
- # as an object-based checkpoint.
- should_reraise = False
+ except errors.NotFoundError as err:
+ # There are three common conditions that might cause this error:
+ # 0. The file is missing. We ignore here, as this is checked above.
+ # 1. This is an object-based checkpoint trying name-based loading.
+ # 2. The graph has been altered and a variable or other name is missing.
+
+ # 1. The checkpoint would not be loaded successfully as is. Try to parse
+ # it as an object-based checkpoint.
try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError:
- # This is not an object-based checkpoint, or the checkpoint doesn't
- # exist. Re-raise the original exception, but do it outside the except
- # block so the object graph lookup isn't included in the stack trace.
- should_reraise = True
- if should_reraise:
- six.reraise(exception_type, exception_value, exception_traceback)
- del exception_traceback # avoid reference cycles
+ # 2. This is not an object-based checkpoint, which likely means there
+ # is a graph mismatch. Re-raise the original error with
+ # a helpful message (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a Variable name or other graph key that is missing")
# This is an object-based checkpoint. We'll print a warning and then do
# the restore.
@@ -1747,6 +1754,11 @@ class Saver(object):
self._restore_from_object_based_checkpoint(
sess=sess, save_path=save_path,
object_graph_string=object_graph_string)
+ except errors.InvalidArgumentError as err:
+ # There is a mismatch between the graph and the checkpoint being loaded.
+ # We add a more reasonable error message here to help users (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a mismatch between the current graph and the graph")
def _restore_from_object_based_checkpoint(self, sess, save_path,
object_graph_string):
@@ -1913,6 +1925,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled.
@end_compatibility
""" # pylint: disable=g-doc-exception
+ return _import_meta_graph_with_return_elements(
+ meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
+
+
+def _import_meta_graph_with_return_elements(
+ meta_graph_or_file, clear_devices=False, import_scope=None,
+ return_elements=None, **kwargs):
+ """Import MetaGraph, and return both a saver and returned elements."""
if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
@@ -1922,12 +1942,22 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
else:
meta_graph_def = meta_graph_or_file
- imported_vars = meta_graph.import_scoped_meta_graph(
- meta_graph_def,
- clear_devices=clear_devices,
- import_scope=import_scope,
- **kwargs)
+ imported_vars, imported_return_elements = (
+ meta_graph.import_scoped_meta_graph_with_return_elements(
+ meta_graph_def,
+ clear_devices=clear_devices,
+ import_scope=import_scope,
+ return_elements=return_elements,
+ **kwargs))
+
+ saver = _create_saver_from_imported_meta_graph(
+ meta_graph_def, import_scope, imported_vars)
+ return saver, imported_return_elements
+
+def _create_saver_from_imported_meta_graph(
+ meta_graph_def, import_scope, imported_vars):
+ """Return a saver for restoring variable values to an imported MetaGraph."""
if meta_graph_def.HasField("saver_def"):
# Infer the scope that is prepended by `import_scoped_meta_graph`.
scope = import_scope
@@ -2139,6 +2169,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
return meta_graph_filename
+def _wrap_restore_error_with_msg(err, extra_verbiage):
+ err_msg = ("Restoring from checkpoint failed. This is most likely "
+ "due to {} from the checkpoint. Please ensure that you "
+ "have not altered the graph expected based on the checkpoint. "
+ "Original error:\n\n{}").format(extra_verbiage, err.message)
+ return err.__class__(err.node_def, err.op, err_msg)
+
+
ops.register_proto_function(
ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef,