Diffstat (limited to 'tensorflow/python/training/checkpoint_management.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_management.py | 293
1 files changed, 284 insertions, 9 deletions
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index aaddc015ed..85f2904318 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -19,16 +19,23 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import os.path
 import re
+import time
 
 from google.protobuf import text_format
 
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
@@ -51,7 +58,9 @@ def _GetCheckpointFilename(save_dir, latest_filename):
 @tf_export("train.generate_checkpoint_state_proto")
 def generate_checkpoint_state_proto(save_dir,
                                     model_checkpoint_path,
-                                    all_model_checkpoint_paths=None):
+                                    all_model_checkpoint_paths=None,
+                                    all_model_checkpoint_timestamps=None,
+                                    last_preserved_timestamp=None):
   """Generates a checkpoint state proto.
 
   Args:
@@ -61,11 +70,20 @@ def generate_checkpoint_state_proto(save_dir,
       checkpoints, sorted from oldest to newest.  If this is a non-empty list,
       the last element must be equal to model_checkpoint_path.  These paths
       are also saved in the CheckpointState proto.
-
+    all_model_checkpoint_timestamps: A list of floats, indicating the number of
+      seconds since the Epoch when each checkpoint was generated.
+    last_preserved_timestamp: A float, indicating the number of seconds since
+      the Epoch when the last preserved checkpoint was written, e.g. due to a
+      `keep_checkpoint_every_n_hours` parameter (see
+      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
   Returns:
     CheckpointState proto with model_checkpoint_path and
     all_model_checkpoint_paths updated to either absolute paths or
     relative paths to the current save_dir.
+
+  Raises:
+    ValueError: If `all_model_checkpoint_timestamps` was provided but its length
+      does not match `all_model_checkpoint_paths`.
   """
   if all_model_checkpoint_paths is None:
     all_model_checkpoint_paths = []
@@ -76,6 +94,14 @@ def generate_checkpoint_state_proto(save_dir,
                              model_checkpoint_path)
     all_model_checkpoint_paths.append(model_checkpoint_path)
 
+  if (all_model_checkpoint_timestamps
+      and (len(all_model_checkpoint_timestamps)
+           != len(all_model_checkpoint_paths))):
+    raise ValueError(
+        ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
+         "paths %s and timestamps %s)")
+        % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
+
   # Relative paths need to be rewritten to be relative to the "save_dir"
   # if model_checkpoint_path already contains "save_dir".
   if not os.path.isabs(save_dir):
@@ -88,7 +114,9 @@ def generate_checkpoint_state_proto(save_dir,
 
   coord_checkpoint_proto = CheckpointState(
       model_checkpoint_path=model_checkpoint_path,
-      all_model_checkpoint_paths=all_model_checkpoint_paths)
+      all_model_checkpoint_paths=all_model_checkpoint_paths,
+      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+      last_preserved_timestamp=last_preserved_timestamp)
 
   return coord_checkpoint_proto
 
@@ -97,7 +125,9 @@
 def update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
-                            latest_filename=None):
+                            latest_filename=None,
+                            all_model_checkpoint_timestamps=None,
+                            last_preserved_timestamp=None):
   """Updates the content of the 'checkpoint' file.
 
   This updates the checkpoint file containing a CheckpointState
@@ -112,7 +142,13 @@ def update_checkpoint_state(save_dir,
       are also saved in the CheckpointState proto.
     latest_filename: Optional name of the checkpoint file.  Default to
       'checkpoint'.
-
+    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+      seconds since the Epoch) indicating when the checkpoints in
+      `all_model_checkpoint_paths` were created.
+    last_preserved_timestamp: A float, indicating the number of seconds since
+      the Epoch when the last preserved checkpoint was written, e.g. due to a
+      `keep_checkpoint_every_n_hours` parameter (see
+      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
   Raises:
     RuntimeError: If any of the model checkpoint paths conflict with the file
       containing CheckpointState.
@@ -122,14 +158,18 @@
       model_checkpoint_path=model_checkpoint_path,
       all_model_checkpoint_paths=all_model_checkpoint_paths,
       latest_filename=latest_filename,
-      save_relative_paths=False)
+      save_relative_paths=False,
+      all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+      last_preserved_timestamp=last_preserved_timestamp)
 
 
 def update_checkpoint_state_internal(save_dir,
                                      model_checkpoint_path,
                                      all_model_checkpoint_paths=None,
                                      latest_filename=None,
-                                     save_relative_paths=False):
+                                     save_relative_paths=False,
+                                     all_model_checkpoint_timestamps=None,
+                                     last_preserved_timestamp=None):
   """Updates the content of the 'checkpoint' file.
 
   This updates the checkpoint file containing a CheckpointState
@@ -146,6 +186,13 @@ def update_checkpoint_state_internal(save_dir,
       'checkpoint'.
     save_relative_paths: If `True`, will write relative paths to the
      checkpoint state file.
+    all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+      seconds since the Epoch) indicating when the checkpoints in
+      `all_model_checkpoint_paths` were created.
+    last_preserved_timestamp: A float, indicating the number of seconds since
+      the Epoch when the last preserved checkpoint was written, e.g. due to a
+      `keep_checkpoint_every_n_hours` parameter (see
+      `tf.contrib.checkpoint.CheckpointManager` for an implementation).
 
   Raises:
     RuntimeError: If any of the model checkpoint paths conflict with the file
@@ -168,12 +215,16 @@ def update_checkpoint_state_internal(save_dir,
     ckpt = generate_checkpoint_state_proto(
         save_dir,
         rel_model_checkpoint_path,
-        all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
+        all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
+        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+        last_preserved_timestamp=last_preserved_timestamp)
   else:
     ckpt = generate_checkpoint_state_proto(
         save_dir,
         model_checkpoint_path,
-        all_model_checkpoint_paths=all_model_checkpoint_paths)
+        all_model_checkpoint_paths=all_model_checkpoint_paths,
+        all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+        last_preserved_timestamp=last_preserved_timestamp)
 
   if coord_checkpoint_filename == ckpt.model_checkpoint_path:
     raise RuntimeError("Save path '%s' conflicts with path used for "
@@ -404,3 +455,227 @@ def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
   basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
   suffixed_filename = ".".join([basename, meta_graph_suffix])
   return suffixed_filename
+
+
+# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
+class CheckpointManager(object):
+  """Deletes old checkpoints.
+
+  Example usage:
+  ```python
+  import tensorflow as tf
+  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+  manager = tf.contrib.checkpoint.CheckpointManager(
+      checkpoint, directory="/tmp/model", max_to_keep=5)
+  status = checkpoint.restore(manager.latest_checkpoint)
+  while True:
+    # train
+    manager.save()
+  ```
+
+  `CheckpointManager` preserves its own state across instantiations (see the
+  `__init__` documentation for details). Only one should be active in a
+  particular directory at a time.
+  """
+
+  def __init__(self, checkpoint, directory,
+               max_to_keep, keep_checkpoint_every_n_hours=None):
+    """Configure a `CheckpointManager` for use in `directory`.
+
+    If a `CheckpointManager` was previously used in `directory`, its
+    state will be restored. This includes the list of managed checkpoints and
+    the timestamp bookkeeping necessary to support
+    `keep_checkpoint_every_n_hours`. The behavior of the new
+    `CheckpointManager` will be the same as the previous `CheckpointManager`,
+    including cleaning up existing checkpoints if appropriate.
+
+    Checkpoints are only considered for deletion just after a new checkpoint
+    has been added. At that point, `max_to_keep` checkpoints will remain in an
+    "active set". Once a checkpoint is preserved by
+    `keep_checkpoint_every_n_hours` it will not be deleted by this
+    `CheckpointManager` or any future `CheckpointManager` instantiated in
+    `directory` (regardless of the new setting of
+    `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
+    active set may be deleted by this `CheckpointManager` or a future
+    `CheckpointManager` instantiated in `directory` (subject to its
+    `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
+
+    Args:
+      checkpoint: The `tf.train.Checkpoint` instance to save and manage
+        checkpoints for.
+      directory: The path to a directory in which to write checkpoints. A
+        special file named "checkpoint" is also written to this directory (in a
+        human-readable text format) which contains the state of the
+        `CheckpointManager`.
+      max_to_keep: An integer, the number of checkpoints to keep. Unless
+        preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
+        deleted from the active set, oldest first, until only `max_to_keep`
+        checkpoints remain.
+      keep_checkpoint_every_n_hours: Upon removal from the active set, a
+        checkpoint will be preserved if it has been at least
+        `keep_checkpoint_every_n_hours` hours since the last preserved
+        checkpoint. The default setting of `None` does not preserve any
+        checkpoints in this way.
+
+    Raises:
+      ValueError: If `max_to_keep` is not a positive integer.
+    """
+    self._checkpoint = checkpoint
+    self._save_counter_assign = None
+    if not max_to_keep or max_to_keep < 0:
+      raise ValueError(
+          "Expected a positive integer for `max_to_keep`, got %d."
+          % (max_to_keep,))
+    self._max_to_keep = max_to_keep
+    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
+    self._directory = directory
+    self._checkpoint_prefix = os.path.join(directory, "ckpt")
+    recovered_state = get_checkpoint_state(directory)
+    current_clock = time.time()
+    self._maybe_delete = collections.OrderedDict()
+    if recovered_state is None:
+      self._latest_checkpoint = None
+      self._last_preserved_timestamp = current_clock
+    else:
+      self._latest_checkpoint = recovered_state.model_checkpoint_path
+      self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
+      if current_clock < self._last_preserved_timestamp:
+        # Time seems to have reversed itself. In addition to this warning, we'll
+        # min() saved checkpoint timestamps with the current time to ensure that
+        # old checkpoints don't get deleted accidentally.
+        logging.warning(
+            ("time.time() returned a value %f seconds behind the last "
+             "preserved checkpoint timestamp.")
+            % (self._last_preserved_timestamp - current_clock,))
+        self._last_preserved_timestamp = current_clock
+      all_timestamps = recovered_state.all_model_checkpoint_timestamps
+      all_paths = recovered_state.all_model_checkpoint_paths
+      del recovered_state  # Uses modified values from now on
+      if not all_timestamps:
+        all_timestamps = [self._last_preserved_timestamp] * len(all_paths)
+      for filename, timestamp in zip(all_paths, all_timestamps):
+        timestamp = min(timestamp, current_clock)
+        if timestamp > self._last_preserved_timestamp:
+          self._maybe_delete[filename] = timestamp
+
+  @property
+  def latest_checkpoint(self):
+    """The prefix of the most recent checkpoint in `directory`.
+
+    Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
+    the constructor argument to `CheckpointManager`.
+
+    Suitable for passing to `tf.train.Checkpoint.restore` to resume training.
+
+    Returns:
+      The checkpoint prefix. If there are no checkpoints, returns `None`.
+    """
+    return self._latest_checkpoint
+
+  @property
+  def checkpoints(self):
+    """A list of managed checkpoints.
+
+    Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
+    show up in this list (to avoid ever-growing filename lists).
+
+    Returns:
+      A list of filenames, sorted from oldest to newest.
+    """
+    return list(self._maybe_delete.keys())
+
+  def _sweep(self):
+    """Deletes or preserves managed checkpoints."""
+    while len(self._maybe_delete) > self._max_to_keep:
+      filename, timestamp = self._maybe_delete.popitem(last=False)
+      # Even if we're keeping this checkpoint due to
+      # keep_checkpoint_every_n_hours, we won't reference it to avoid
+      # infinitely-growing CheckpointState protos.
+      if (self._keep_checkpoint_every_n_hours
+          and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
+               >= self._last_preserved_timestamp)):
+        self._last_preserved_timestamp = timestamp
+        continue
+      remove_checkpoint(filename)
+
+  def _record_state(self):
+    """Saves the `CheckpointManager`'s state in `directory`."""
+    filenames, timestamps = zip(*self._maybe_delete.items())
+    update_checkpoint_state_internal(
+        self._directory,
+        model_checkpoint_path=self.latest_checkpoint,
+        all_model_checkpoint_paths=filenames,
+        all_model_checkpoint_timestamps=timestamps,
+        last_preserved_timestamp=self._last_preserved_timestamp,
+        save_relative_paths=True)
+
+  @property
+  def _prefix(self):
+    """A common prefix for all checkpoints saved with this manager.
+
+    For example, if `directory` (a constructor argument) were
+    `"/tmp/tf-model"`, `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints
+    would generally be numbered `"/tmp/tf-model/ckpt-1"`,
+    `"/tmp/tf-model/ckpt-2"`, and so on. Each checkpoint has several associated
+    files (e.g. `"/tmp/tf-model/ckpt-2.index"`).
+
+    Returns:
+      A string prefix.
+    """
+    return self._checkpoint_prefix
+
+  def save(self, session=None, checkpoint_number=None):
+    """Creates a new checkpoint and manages it.
+
+    Args:
+      session: The session to evaluate variables in. Ignored when executing
+        eagerly. If not provided when graph building, the default session is
+        used.
+      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
+        `Tensor`, used to number the checkpoint. If `None` (default),
+        checkpoints are numbered using `checkpoint.save_counter`. Even if
+        `checkpoint_number` is provided, `save_counter` is still incremented. A
+        user-provided `checkpoint_number` is not incremented even if it is a
+        `Variable`.
+
+    Returns:
+      The path to the new checkpoint. It is also recorded in the `checkpoints`
+      and `latest_checkpoint` properties.
+    """
+    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
+    # slightly with a custom numbering option.
+    if context.executing_eagerly():
+      save_counter = self._checkpoint.save_counter
+      save_counter.assign_add(1)
+    else:
+      if session is None:
+        session = ops.get_default_session()
+
+      def _initializing_creator(next_creator, **kwargs):
+        """Initialize the save counter if it has been newly created."""
+        v = next_creator(**kwargs)
+        session.run(v.initializer)
+        return v
+
+      with variable_scope.variable_creator_scope(_initializing_creator):
+        save_counter = self._checkpoint.save_counter
+      if self._save_counter_assign is None:
+        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
+      session.run(self._save_counter_assign)
+    if checkpoint_number is None:
+      checkpoint_number = save_counter
+    if not isinstance(checkpoint_number, compat.integral_types):
+      checkpoint_number = training_util.global_step(
+          sess=session, global_step_tensor=checkpoint_number)
+    prefix = "%s-%d" % (self._prefix, checkpoint_number)
+    save_path = self._checkpoint.write(prefix)
+    timestamp = time.time()
+    # If this is an overwritten checkpoint we were previously tracking, delete
+    # and reinsert it to make sure it goes to the end of the queue.
+    if save_path in self._maybe_delete:
+      del self._maybe_delete[save_path]
+    self._maybe_delete[save_path] = timestamp
+    self._latest_checkpoint = save_path
+    self._sweep()
+    self._record_state()
+    return save_path
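
Beyond the docstring snippet, a slightly fuller usage sketch may help. This is a minimal example assuming eager execution and the `tf.contrib.checkpoint` export this commit references; the `net`/`opt` names and the `/tmp/model` path are illustrative:

```python
import tensorflow as tf

tf.enable_eager_execution()

net = tf.keras.layers.Dense(2)     # illustrative model
opt = tf.train.AdamOptimizer(0.1)  # illustrative optimizer
checkpoint = tf.train.Checkpoint(optimizer=opt, model=net)
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/model", max_to_keep=5,
    keep_checkpoint_every_n_hours=2)

# Resume from the most recent managed checkpoint, if any.
checkpoint.restore(manager.latest_checkpoint)

for _ in range(10):
  # ... run one training step here ...
  save_path = manager.save()  # e.g. "/tmp/model/ckpt-3"

print(manager.checkpoints)  # active set, oldest to newest (at most 5)
```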
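The preservation rule in `_sweep` is compact; a standalone rendering (a hypothetical helper, not part of the patch, which mirrors the decision without touching the filesystem) makes the arithmetic easier to check:

```python
import collections

def sweep(maybe_delete, max_to_keep, every_n_hours, last_preserved):
  """Mirrors CheckpointManager._sweep; returns (to_delete, last_preserved).

  maybe_delete: OrderedDict of checkpoint path -> timestamp, oldest first.
  """
  to_delete = []
  while len(maybe_delete) > max_to_keep:
    filename, timestamp = maybe_delete.popitem(last=False)
    # Preserve (stop tracking, never delete) if at least every_n_hours have
    # passed since the last preserved checkpoint; otherwise delete.
    if every_n_hours and timestamp - every_n_hours * 3600. >= last_preserved:
      last_preserved = timestamp
      continue
    to_delete.append(filename)
  return to_delete, last_preserved

# Three checkpoints an hour apart, max_to_keep=1, preserve every 1 hour:
ckpts = collections.OrderedDict(
    [("ckpt-1", 0.), ("ckpt-2", 3600.), ("ckpt-3", 7200.)])
deleted, preserved_at = sweep(ckpts, 1, every_n_hours=1, last_preserved=0.)
print(deleted)       # ['ckpt-1'] -- too soon after the last preservation
print(preserved_at)  # 3600.0    -- ckpt-2 is preserved but no longer tracked
```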
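Finally, the cross-instantiation state is just the text-format `checkpoint` file in `directory`. A sketch of inspecting it with the public `tf.train.get_checkpoint_state`, assuming the patched `CheckpointState` proto carries the two new fields (printed values are illustrative):

```python
import tensorflow as tf

state = tf.train.get_checkpoint_state("/tmp/model")
if state is not None:
  print(state.model_checkpoint_path)             # e.g. "ckpt-10"
  print(list(state.all_model_checkpoint_paths))  # active set, oldest first
  # Fields this patch starts populating:
  print(list(state.all_model_checkpoint_timestamps))  # seconds since Epoch
  print(state.last_preserved_timestamp)
```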