Diffstat (limited to 'tensorflow/python/training/checkpoint_management.py')
 tensorflow/python/training/checkpoint_management.py | 293
 1 file changed, 284 insertions(+), 9 deletions(-)
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index aaddc015ed..85f2904318 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -19,16 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import os.path
import re
+import time
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -51,7 +58,9 @@ def _GetCheckpointFilename(save_dir, latest_filename):
@tf_export("train.generate_checkpoint_state_proto")
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
- all_model_checkpoint_paths=None):
+ all_model_checkpoint_paths=None,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Generates a checkpoint state proto.
Args:
@@ -61,11 +70,20 @@ def generate_checkpoint_state_proto(save_dir,
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
-
+ all_model_checkpoint_timestamps: A list of floats, indicating the number of
+ seconds since the Epoch when each checkpoint was generated.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
+
+ Raises:
+ ValueError: If `all_model_checkpoint_timestamps` was provided but its length
+ does not match `all_model_checkpoint_paths`.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
@@ -76,6 +94,14 @@ def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
+ if (all_model_checkpoint_timestamps
+ and (len(all_model_checkpoint_timestamps)
+ != len(all_model_checkpoint_paths))):
+ raise ValueError(
+ ("Checkpoint timestamps, if provided, must match checkpoint paths (got "
+ "paths %s and timestamps %s)")
+ % (all_model_checkpoint_paths, all_model_checkpoint_timestamps))
+
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
@@ -88,7 +114,9 @@ def generate_checkpoint_state_proto(save_dir,
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
return coord_checkpoint_proto
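
For illustration, a minimal sketch of the new timestamp arguments as they might be called; the directory, checkpoint paths, and epoch times here are hypothetical, not part of this change:

```python
from tensorflow.python.training import checkpoint_management

# Two checkpoints tagged with their creation times (seconds since the Epoch).
# If given, the timestamp list must match the path list in length; otherwise
# generate_checkpoint_state_proto raises ValueError.
proto = checkpoint_management.generate_checkpoint_state_proto(
    save_dir="/tmp/model",
    model_checkpoint_path="/tmp/model/ckpt-2",
    all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"],
    all_model_checkpoint_timestamps=[1530000000.0, 1530000060.0],
    last_preserved_timestamp=1529990000.0)
print(proto.last_preserved_timestamp)  # 1529990000.0
```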
@@ -97,7 +125,9 @@ def generate_checkpoint_state_proto(save_dir,
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
- latest_filename=None):
+ latest_filename=None,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@@ -112,7 +142,13 @@ def update_checkpoint_state(save_dir,
are also saved in the CheckpointState proto.
latest_filename: Optional name of the checkpoint file. Default to
'checkpoint'.
-
+ all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+ seconds since the Epoch) indicating when the checkpoints in
+ `all_model_checkpoint_paths` were created.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
@@ -122,14 +158,18 @@ def update_checkpoint_state(save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
- save_relative_paths=False)
+ save_relative_paths=False,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
def update_checkpoint_state_internal(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
- save_relative_paths=False):
+ save_relative_paths=False,
+ all_model_checkpoint_timestamps=None,
+ last_preserved_timestamp=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
@@ -146,6 +186,13 @@ def update_checkpoint_state_internal(save_dir,
'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
+ all_model_checkpoint_timestamps: Optional list of timestamps (floats,
+ seconds since the Epoch) indicating when the checkpoints in
+ `all_model_checkpoint_paths` were created.
+ last_preserved_timestamp: A float, indicating the number of seconds since
+ the Epoch when the last preserved checkpoint was written, e.g. due to a
+ `keep_checkpoint_every_n_hours` parameter (see
+ `tf.contrib.checkpoint.CheckpointManager` for an implementation).
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
@@ -168,12 +215,16 @@ def update_checkpoint_state_internal(save_dir,
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
- all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
+ all_model_checkpoint_paths=rel_all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ all_model_checkpoint_timestamps=all_model_checkpoint_timestamps,
+ last_preserved_timestamp=last_preserved_timestamp)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
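
Under the same hypothetical setup, the extended `update_checkpoint_state` might be exercised as below; it rewrites the human-readable "checkpoint" file in `save_dir`, which is assumed to already exist:

```python
import time

from tensorflow.python.training import checkpoint_management

now = time.time()
# Record two checkpoints plus the bookkeeping used by
# keep_checkpoint_every_n_hours-style preservation logic.
checkpoint_management.update_checkpoint_state(
    save_dir="/tmp/model",
    model_checkpoint_path="/tmp/model/ckpt-2",
    all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"],
    all_model_checkpoint_timestamps=[now - 60.0, now],
    last_preserved_timestamp=now - 3600.0)
```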
@@ -404,3 +455,227 @@ def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
suffixed_filename = ".".join([basename, meta_graph_suffix])
return suffixed_filename
+
+
+# TODO(allenl): Allow tf.keras.Model instances in the constructor directly?
+class CheckpointManager(object):
+ """Deletes old checkpoints.
+
+ Example usage:
+ ```python
+ import tensorflow as tf
+ checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
+ manager = tf.contrib.checkpoint.CheckpointManager(
+ checkpoint, directory="/tmp/model", max_to_keep=5)
+ status = checkpoint.restore(manager.latest_checkpoint)
+ while True:
+ # train
+ manager.save()
+ ```
+
+ `CheckpointManager` preserves its own state across instantiations (see the
+ `__init__` documentation for details). Only one should be active in a
+ particular directory at a time.
+ """
+
+ def __init__(self, checkpoint, directory,
+ max_to_keep, keep_checkpoint_every_n_hours=None):
+ """Configure a `CheckpointManager` for use in `directory`.
+
+ If a `CheckpointManager` was previously used in `directory`, its
+ state will be restored. This includes the list of managed checkpoints and
+ the timestamp bookkeeping necessary to support
+ `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager`
+ will be the same as the previous `CheckpointManager`, including cleaning up
+ existing checkpoints if appropriate.
+
+ Checkpoints are only considered for deletion just after a new checkpoint has
+ been added. At that point, `max_to_keep` checkpoints will remain in an
+ "active set". Once a checkpoint is preserved by
+ `keep_checkpoint_every_n_hours` it will not be deleted by this
+ `CheckpointManager` or any future `CheckpointManager` instantiated in
+ `directory` (regardless of the new setting of
+ `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the
+ active set may be deleted by this `CheckpointManager` or a future
+ `CheckpointManager` instantiated in `directory` (subject to its
+ `max_to_keep` and `keep_checkpoint_every_n_hours` settings).
+
+ Args:
+ checkpoint: The `tf.train.Checkpoint` instance to save and manage
+ checkpoints for.
+ directory: The path to a directory in which to write checkpoints. A
+ special file named "checkpoint" is also written to this directory (in a
+ human-readable text format) which contains the state of the
+ `CheckpointManager`.
+ max_to_keep: An integer, the number of checkpoints to keep. Unless
+ preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
+ deleted from the active set, oldest first, until only `max_to_keep`
+ checkpoints remain.
+ keep_checkpoint_every_n_hours: Upon removal from the active set, a
+        checkpoint will be preserved if at least
+        `keep_checkpoint_every_n_hours` hours have passed since the last
+        preserved checkpoint. The default setting of `None` does not preserve
+        any checkpoints in this way.
+
+ Raises:
+ ValueError: If `max_to_keep` is not a positive integer.
+ """
+ self._checkpoint = checkpoint
+ self._save_counter_assign = None
+ if not max_to_keep or max_to_keep < 0:
+ raise ValueError(
+          "Expected a positive integer for `max_to_keep`, got %s."
+          % (max_to_keep,))
+ self._max_to_keep = max_to_keep
+ self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
+ self._directory = directory
+ self._checkpoint_prefix = os.path.join(directory, "ckpt")
+ recovered_state = get_checkpoint_state(directory)
+ current_clock = time.time()
+ self._maybe_delete = collections.OrderedDict()
+ if recovered_state is None:
+ self._latest_checkpoint = None
+ self._last_preserved_timestamp = current_clock
+ else:
+ self._latest_checkpoint = recovered_state.model_checkpoint_path
+ self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
+ if current_clock < self._last_preserved_timestamp:
+ # Time seems to have reversed itself. In addition to this warning, we'll
+ # min() saved checkpoint timestamps with the current time to ensure that
+ # old checkpoints don't get deleted accidentally.
+ logging.warning(
+ ("time.time() returned a value %f seconds behind the last "
+ "preserved checkpoint timestamp.")
+ % (self._last_preserved_timestamp - current_clock,))
+ self._last_preserved_timestamp = current_clock
+ all_timestamps = recovered_state.all_model_checkpoint_timestamps
+ all_paths = recovered_state.all_model_checkpoint_paths
+ del recovered_state # Uses modified values from now on
+ if not all_timestamps:
+ all_timestamps = [self._last_preserved_timestamp] * len(all_paths)
+
+ for filename, timestamp in zip(all_paths, all_timestamps):
+ timestamp = min(timestamp, current_clock)
+ if timestamp > self._last_preserved_timestamp:
+ self._maybe_delete[filename] = timestamp
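
To make the recovery behavior concrete, a small sketch under stated assumptions: eager execution, an already-existing `/tmp/recovery_demo` directory, and the contrib-era `tf.contrib.eager.Variable`. A fresh process would normally construct the second manager, since only one should be active per directory at a time:

```python
import tensorflow as tf
tf.enable_eager_execution()

checkpoint = tf.train.Checkpoint(v=tf.contrib.eager.Variable(1.0))
first = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/recovery_demo", max_to_keep=2)
first.save()

# A second manager in the same directory recovers the managed checkpoint
# list and the timestamp bookkeeping from the "checkpoint" file above.
second = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/recovery_demo", max_to_keep=2)
assert second.latest_checkpoint == first.latest_checkpoint
assert second.checkpoints == first.checkpoints
```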
+
+ @property
+ def latest_checkpoint(self):
+ """The prefix of the most recent checkpoint in `directory`.
+
+ Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is
+ the constructor argument to `CheckpointManager`.
+
+ Suitable for passing to `tf.train.Checkpoint.restore` to resume training.
+
+ Returns:
+ The checkpoint prefix. If there are no checkpoints, returns `None`.
+ """
+ return self._latest_checkpoint
+
+ @property
+ def checkpoints(self):
+ """A list of managed checkpoints.
+
+ Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not
+ show up in this list (to avoid ever-growing filename lists).
+
+ Returns:
+ A list of filenames, sorted from oldest to newest.
+ """
+ return list(self._maybe_delete.keys())
+
+ def _sweep(self):
+ """Deletes or preserves managed checkpoints."""
+ while len(self._maybe_delete) > self._max_to_keep:
+ filename, timestamp = self._maybe_delete.popitem(last=False)
+ # Even if we're keeping this checkpoint due to
+ # keep_checkpoint_every_n_hours, we won't reference it to avoid
+ # infinitely-growing CheckpointState protos.
+ if (self._keep_checkpoint_every_n_hours
+ and (timestamp - self._keep_checkpoint_every_n_hours * 3600.
+ >= self._last_preserved_timestamp)):
+ self._last_preserved_timestamp = timestamp
+ continue
+ remove_checkpoint(filename)
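
To pin down the arithmetic, a self-contained sketch of the preservation test `_sweep` applies; the helper name is illustrative, not part of the API:

```python
def should_preserve(timestamp, last_preserved_timestamp,
                    keep_checkpoint_every_n_hours):
  """Mirrors _sweep: preserve a checkpoint leaving the active set iff at
  least n hours have elapsed since the previously preserved checkpoint."""
  if not keep_checkpoint_every_n_hours:
    return False
  return (timestamp - keep_checkpoint_every_n_hours * 3600.0
          >= last_preserved_timestamp)

# With keep_checkpoint_every_n_hours=2 and the last preservation at t=0,
# a checkpoint evicted with timestamp 7500s is preserved (7500 - 7200 >= 0),
# while one with timestamp 7000s is deleted.
assert should_preserve(7500.0, 0.0, 2)
assert not should_preserve(7000.0, 0.0, 2)
```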
+
+ def _record_state(self):
+ """Saves the `CheckpointManager`'s state in `directory`."""
+ filenames, timestamps = zip(*self._maybe_delete.items())
+ update_checkpoint_state_internal(
+ self._directory,
+ model_checkpoint_path=self.latest_checkpoint,
+ all_model_checkpoint_paths=filenames,
+ all_model_checkpoint_timestamps=timestamps,
+ last_preserved_timestamp=self._last_preserved_timestamp,
+ save_relative_paths=True)
+
+ @property
+ def _prefix(self):
+ """A common prefix for all checkpoints saved with this manager.
+
+ For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`,
+ `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be
+ numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each
+ checkpoint has several associated files
+ (e.g. `"/tmp/tf-model/ckpt-2.index"`).
+
+ Returns:
+ A string prefix.
+ """
+ return self._checkpoint_prefix
+
+ def save(self, session=None, checkpoint_number=None):
+ """Creates a new checkpoint and manages it.
+
+ Args:
+ session: The session to evaluate variables in. Ignored when executing
+ eagerly. If not provided when graph building, the default session is
+ used.
+ checkpoint_number: An optional integer, or an integer-dtype `Variable` or
+ `Tensor`, used to number the checkpoint. If `None` (default),
+ checkpoints are numbered using `checkpoint.save_counter`. Even if
+ `checkpoint_number` is provided, `save_counter` is still incremented. A
+ user-provided `checkpoint_number` is not incremented even if it is a
+ `Variable`.
+
+ Returns:
+ The path to the new checkpoint. It is also recorded in the `checkpoints`
+      and `latest_checkpoint` properties.
+ """
+ # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
+ # slightly with a custom numbering option.
+ if context.executing_eagerly():
+ save_counter = self._checkpoint.save_counter
+ save_counter.assign_add(1)
+ else:
+ if session is None:
+ session = ops.get_default_session()
+
+ def _initializing_creator(next_creator, **kwargs):
+ """Initialize the save counter if it has been newly created."""
+ v = next_creator(**kwargs)
+ session.run(v.initializer)
+ return v
+
+ with variable_scope.variable_creator_scope(_initializing_creator):
+ save_counter = self._checkpoint.save_counter
+ if self._save_counter_assign is None:
+ self._save_counter_assign = save_counter.assign_add(1, read_value=False)
+ session.run(self._save_counter_assign)
+ if checkpoint_number is None:
+ checkpoint_number = save_counter
+ if not isinstance(checkpoint_number, compat.integral_types):
+ checkpoint_number = training_util.global_step(
+ sess=session, global_step_tensor=checkpoint_number)
+ prefix = "%s-%d" % (self._prefix, checkpoint_number)
+ save_path = self._checkpoint.write(prefix)
+ timestamp = time.time()
+ # If this is an overwritten checkpoint we were previously tracking, delete
+ # and reinsert it to make sure it goes to the end of the queue.
+ if save_path in self._maybe_delete:
+ del self._maybe_delete[save_path]
+ self._maybe_delete[save_path] = timestamp
+ self._latest_checkpoint = save_path
+ self._sweep()
+ self._record_state()
+ return save_path
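
Finally, a sketch of the custom-numbering path in `save`, under the same assumptions as above (eager mode, contrib-era API, and an existing `/tmp/numbered_demo` directory):

```python
import tensorflow as tf
tf.enable_eager_execution()

checkpoint = tf.train.Checkpoint(v=tf.contrib.eager.Variable(1.0))
manager = tf.contrib.checkpoint.CheckpointManager(
    checkpoint, directory="/tmp/numbered_demo", max_to_keep=3)

# Number the checkpoint explicitly instead of with checkpoint.save_counter;
# save_counter is still incremented as a side effect.
path = manager.save(checkpoint_number=42)
print(path)  # "/tmp/numbered_demo/ckpt-42"
```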