author    Allen Lavoie <allenl@google.com>    2018-08-24 11:26:38 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-24 11:45:16 -0700
commit    89cd5087643bdf7a2a12996e8d21b916c7f25ec3 (patch)
tree      c379b8e62d0857142d95d93fab95381092f50dd7 /tensorflow/python/training
parent    83839064dd8061089a7fdf69e1065655b432c4fd (diff)
Add a max_to_keep=None option to CheckpointManager
With `max_to_keep=None` nothing is deleted, and paths to all checkpoints are kept; I will follow up with a way to remove them manually.

PiperOrigin-RevId: 210128785
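For context, a minimal usage sketch of the new option. The `tf.train.Checkpoint` / `tf.train.CheckpointManager` names below are the public aliases of the classes this diff touches; treat the directory and the loop as illustrative:

import tensorflow as tf

# Sketch only: assumes the public tf.train aliases of the classes in this
# diff (the change itself lives in checkpoint_management.CheckpointManager).
checkpoint = tf.train.Checkpoint()
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpts", max_to_keep=None)

# With max_to_keep=None every save is retained: nothing is deleted and
# every path stays in manager.checkpoints.
for _ in range(3):
    manager.save()
print(manager.checkpoints)  # paths to all three checkpoints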
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--    tensorflow/python/training/checkpoint_management.py         14
-rw-r--r--    tensorflow/python/training/checkpoint_management_test.py    44
2 files changed, 55 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 85f2904318..b7aa8264b0 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -510,7 +510,10 @@ class CheckpointManager(object):
max_to_keep: An integer, the number of checkpoints to keep. Unless
preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
deleted from the active set, oldest first, until only `max_to_keep`
- checkpoints remain.
+ checkpoints remain. If `None`, no checkpoints are deleted and everything
+ stays in the active set. Note that `max_to_keep=None` will keep all
+ checkpoint paths in memory and in the checkpoint state protocol buffer
+ on disk.
keep_checkpoint_every_n_hours: Upon removal from the active set, a
checkpoint will be preserved if it has been at least
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
@@ -521,9 +524,10 @@ class CheckpointManager(object):
"""
self._checkpoint = checkpoint
self._save_counter_assign = None
- if not max_to_keep or max_to_keep < 0:
+ if max_to_keep is not None and max_to_keep <= 0:
raise ValueError(
- "Expected a positive integer for `max_to_max_to_keep`, got %d."
+ ("Expected a positive integer or `None` for `max_to_max_to_keep`, "
+ "got %d.")
% (max_to_keep,))
self._max_to_keep = max_to_keep
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
@@ -586,6 +590,10 @@ class CheckpointManager(object):
def _sweep(self):
"""Deletes or preserves managed checkpoints."""
+ if not self._max_to_keep:
+ # Does not update self._last_preserved_timestamp, since everything is kept
+ # in the active set.
+ return
while len(self._maybe_delete) > self._max_to_keep:
filename, timestamp = self._maybe_delete.popitem(last=False)
# Even if we're keeping this checkpoint due to
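As a standalone sketch of the `_sweep` logic above (hypothetical names; in the real class `_maybe_delete` is an OrderedDict mapping checkpoint paths to save timestamps, oldest first):

import collections

# Sketch of CheckpointManager._sweep, assuming an oldest-first
# OrderedDict of {checkpoint_path: timestamp}.
def sweep(maybe_delete, max_to_keep):
    if not max_to_keep:
        # max_to_keep=None: keep everything in the active set.
        return
    while len(maybe_delete) > max_to_keep:
        # popitem(last=False) removes the oldest entry first.
        filename, timestamp = maybe_delete.popitem(last=False)
        print("would delete", filename, "saved at", timestamp)

maybe_delete = collections.OrderedDict(
    [("ckpt-1", 1.0), ("ckpt-2", 2.0), ("ckpt-3", 3.0)])
sweep(maybe_delete, max_to_keep=None)  # keeps all three
sweep(maybe_delete, max_to_keep=2)     # drops ckpt-1 only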
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 22c2cc678a..d7162265e6 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
@@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
@test_util.run_in_graph_and_eager_modes
+ def testKeepAll(self):
+ checkpoint = util.Checkpoint()
+ directory = os.path.join(
+ self.get_temp_dir(),
+ # Avoid sharing directories between eager and graph
+ # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
+ str(context.executing_eagerly()))
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertEqual(third_path, manager.latest_checkpoint)
+ self.assertEqual([first_path, second_path, third_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ fourth_path = manager.save()
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=3)
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ fifth_path = manager.save()
+ self.assertEqual([third_path, fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+
+ @test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testSaveRestoreState(self, mock_time):
directory = self.get_temp_dir()
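One consequence the new docstring warns about: with `max_to_keep=None`, every checkpoint path accumulates in the checkpoint state protocol buffer on disk. A quick way to observe this, using `tf.train.get_checkpoint_state` (which reads that same proto; the directory is illustrative):

import tensorflow as tf

# Reads the CheckpointState proto that CheckpointManager maintains.
state = tf.train.get_checkpoint_state("/tmp/ckpts")
print(state.model_checkpoint_path)            # newest checkpoint
print(len(state.all_model_checkpoint_paths))  # grows without bound
                                              # under max_to_keep=None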