diff options
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r-- | tensorflow/python/training/saver.py | 84 |
1 files changed, 61 insertions, 23 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 53ed89e4ab..c80cdf03be 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,7 +22,6 @@ from __future__ import print_function import collections import os.path import re -import sys import time import uuid @@ -127,8 +126,10 @@ class BaseSaverBuilder(object): def f(): with ops.device(v.device): x = v.read_value() - with ops.device("/device:CPU:0"): - return array_ops.identity(x) + # To allow variables placed on non-CPU devices to be checkpointed, + # we copy them to CPU on the same machine first. + with ops.device("/device:CPU:0"): + return array_ops.identity(x) return f self.handle_op = var.handle @@ -1043,8 +1044,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): ckpt = CheckpointState() text_format.Merge(file_content, ckpt) if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) # For relative model_checkpoint_path and all_model_checkpoint_paths, # prepend checkpoint_dir. if not os.path.isabs(ckpt.model_checkpoint_path): @@ -1706,12 +1707,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1719,23 +1725,24 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. - should_reraise = False + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: reader = pywrap_tensorflow.NewCheckpointReader(save_path) object_graph_string = reader.get_tensor( checkpointable.OBJECT_GRAPH_PROTO_KEY) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception, but do it outside the except - # block so the object graph lookup isn't included in the stack trace. - should_reraise = True - if should_reraise: - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1747,6 +1754,11 @@ class Saver(object): self._restore_from_object_based_checkpoint( sess=sess, save_path=save_path, object_graph_string=object_graph_string) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") def _restore_from_object_based_checkpoint(self, sess, save_path, object_graph_string): @@ -1913,6 +1925,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, execution is enabled. @end_compatibility """ # pylint: disable=g-doc-exception + return _import_meta_graph_with_return_elements( + meta_graph_or_file, clear_devices, import_scope, **kwargs)[0] + + +def _import_meta_graph_with_return_elements( + meta_graph_or_file, clear_devices=False, import_scope=None, + return_elements=None, **kwargs): + """Import MetaGraph, and return both a saver and returned elements.""" if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " @@ -1922,12 +1942,22 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, else: meta_graph_def = meta_graph_or_file - imported_vars = meta_graph.import_scoped_meta_graph( - meta_graph_def, - clear_devices=clear_devices, - import_scope=import_scope, - **kwargs) + imported_vars, imported_return_elements = ( + meta_graph.import_scoped_meta_graph_with_return_elements( + meta_graph_def, + clear_devices=clear_devices, + import_scope=import_scope, + return_elements=return_elements, + **kwargs)) + + saver = _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars) + return saver, imported_return_elements + +def _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars): + """Return a saver for restoring variable values to an imported MetaGraph.""" if meta_graph_def.HasField("saver_def"): # Infer the scope that is prepended by `import_scoped_meta_graph`. scope = import_scope @@ -2139,6 +2169,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): return meta_graph_filename +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. " + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, |