diff options
author | 2017-09-18 15:57:06 -0700 | |
---|---|---|
committer | 2017-09-18 16:00:11 -0700 | |
commit | 662a2e6c597a5f5d3c726992676a600bbdf3e53d (patch) | |
tree | ecb5a041f73d3ec155e3f2b1c3a6e9e9382455fc /tensorflow/python/training/saver.py | |
parent | 7de939bb74c5edbc2f45e77a5d4696e70bb59e5b (diff) |
Update the checkpoints index file in CheckpointProto before actually deleting files.
Current order of execution is:
1. Save the new checkpoint
2. Delete old checkpoint files
3. Update the checkpoint proto
If the job is preempted after step 2, the checkpoint proto is left pointing to a deleted file.
It is better to update the checkpoint proto first:
1. write new checkpoint
2. update checkpoint proto
3. delete old checkpoint
I added tests covering the checkpoint proto in CL 168744975, and they still pass with this change.
PiperOrigin-RevId: 169161095
Diffstat (limited to 'tensorflow/python/training/saver.py')
-rw-r--r-- | tensorflow/python/training/saver.py | 41 |
1 file changed, 23 insertions, 18 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 3aae6e17b4..138f566835 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1273,6 +1273,7 @@ class Saver(object): self._next_checkpoint_time = ( time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600) self._last_checkpoints = [] + self._checkpoints_to_be_deleted = [] def _check_saver_def(self): if not isinstance(self.saver_def, saver_pb2.SaverDef): @@ -1315,21 +1316,8 @@ class Saver(object): meta_graph_filename = ".".join([basename, meta_graph_suffix]) return meta_graph_filename - def _MaybeDeleteOldCheckpoints(self, - latest_save_path, - meta_graph_suffix="meta"): - """Deletes old checkpoints if necessary. - - Always keep the last `max_to_keep` checkpoints. If - `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint - every `N` hours. For example, if `N` is 0.5, an additional checkpoint is - kept for every 0.5 hours of training; if `N` is 10, an additional - checkpoint is kept for every 10 hours of training. - - Args: - latest_save_path: Name including path of checkpoint file to save. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - """ + def _RecordLastCheckpoint(self, latest_save_path): + """Manages the list of the latest checkpoints.""" if not self.saver_def.max_to_keep: return # Remove first from list if the same name was used before. @@ -1338,9 +1326,26 @@ class Saver(object): self._last_checkpoints.remove(p) # Append new path to list self._last_checkpoints.append((latest_save_path, time.time())) + # If more than max_to_keep, remove oldest. if len(self._last_checkpoints) > self.saver_def.max_to_keep: - p = self._last_checkpoints.pop(0) + self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0)) + + def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"): + """Deletes old checkpoints if necessary. 
+ + `self._checkpoints_to_be_deleted` is going to contain checkpoints that are + over `max_to_keep`. They are going to be deleted. If + `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint + every `N` hours. For example, if `N` is 0.5, an additional checkpoint is + kept for every 0.5 hours of training; if `N` is 10, an additional + checkpoint is kept for every 10 hours of training. + + Args: + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + """ + if self._checkpoints_to_be_deleted: + p = self._checkpoints_to_be_deleted.pop(0) # Do not delete the file if we keep_checkpoint_every_n_hours is set and we # have reached N hours of training. should_keep = p[1] > self._next_checkpoint_time @@ -1569,14 +1574,14 @@ class Saver(object): model_checkpoint_path = compat.as_str(model_checkpoint_path) if write_state: - self._MaybeDeleteOldCheckpoints( - model_checkpoint_path, meta_graph_suffix=meta_graph_suffix) + self._RecordLastCheckpoint(model_checkpoint_path) _update_checkpoint_state( save_dir=save_path_parent, model_checkpoint_path=model_checkpoint_path, all_model_checkpoint_paths=self.last_checkpoints, latest_filename=latest_filename, save_relative_paths=self._save_relative_paths) + self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix) except (errors.FailedPreconditionError, errors.NotFoundError) as exc: if not gfile.IsDirectory(save_path_parent): exc = ValueError( |