diff options
author | 2018-08-24 11:26:38 -0700 | |
---|---|---|
committer | 2018-08-24 11:45:16 -0700 | |
commit | 89cd5087643bdf7a2a12996e8d21b916c7f25ec3 (patch) | |
tree | c379b8e62d0857142d95d93fab95381092f50dd7 /tensorflow/python/training | |
parent | 83839064dd8061089a7fdf69e1065655b432c4fd (diff) |
Add a max_to_keep=None option to CheckpointManager
With `max_to_keep=None`, the manager doesn't delete anything. Also keeps paths to all checkpoints; I will follow up with a way to remove them manually.
PiperOrigin-RevId: 210128785
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/checkpoint_management.py | 14 | ||||
-rw-r--r-- | tensorflow/python/training/checkpoint_management_test.py | 44 |
2 files changed, 55 insertions, 3 deletions
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py index 85f2904318..b7aa8264b0 100644 --- a/tensorflow/python/training/checkpoint_management.py +++ b/tensorflow/python/training/checkpoint_management.py @@ -510,7 +510,10 @@ class CheckpointManager(object): max_to_keep: An integer, the number of checkpoints to keep. Unless preserved by `keep_checkpoint_every_n_hours`, checkpoints will be deleted from the active set, oldest first, until only `max_to_keep` - checkpoints remain. + checkpoints remain. If `None`, no checkpoints are deleted and everything + stays in the active set. Note that `max_to_keep=None` will keep all + checkpoint paths in memory and in the checkpoint state protocol buffer + on disk. keep_checkpoint_every_n_hours: Upon removal from the active set, a checkpoint will be preserved if it has been at least `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The @@ -521,9 +524,10 @@ class CheckpointManager(object): """ self._checkpoint = checkpoint self._save_counter_assign = None - if not max_to_keep or max_to_keep < 0: + if max_to_keep is not None and max_to_keep <= 0: raise ValueError( - "Expected a positive integer for `max_to_max_to_keep`, got %d." + ("Expected a positive integer or `None` for `max_to_max_to_keep`, " + "got %d.") % (max_to_keep,)) self._max_to_keep = max_to_keep self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours @@ -586,6 +590,10 @@ class CheckpointManager(object): def _sweep(self): """Deletes or preserves managed checkpoints.""" + if not self._max_to_keep: + # Does not update self._last_preserved_timestamp, since everything is kept + # in the active set. 
+ return while len(self._maybe_delete) > self._max_to_keep: filename, timestamp = self._maybe_delete.popitem(last=False) # Even if we're keeping this checkpoint due to diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py index 22c2cc678a..d7162265e6 100644 --- a/tensorflow/python/training/checkpoint_management_test.py +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -26,6 +26,7 @@ import tempfile from google.protobuf import text_format from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util @@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase): self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) @test_util.run_in_graph_and_eager_modes + def testKeepAll(self): + checkpoint = util.Checkpoint() + directory = os.path.join( + self.get_temp_dir(), + # Avoid sharing directories between eager and graph + # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories + str(context.executing_eagerly())) + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=None) + first_path = manager.save() + second_path = manager.save() + third_path = manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(first_path)) + self.assertEqual(third_path, manager.latest_checkpoint) + self.assertEqual([first_path, second_path, third_path], + manager.checkpoints) + del manager + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=None) + fourth_path = manager.save() + self.assertEqual([first_path, second_path, third_path, fourth_path], + manager.checkpoints) + 
del manager + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=3) + self.assertEqual([first_path, second_path, third_path, fourth_path], + manager.checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(first_path)) + fifth_path = manager.save() + self.assertEqual([third_path, fourth_path, fifth_path], + manager.checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(second_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) + + @test_util.run_in_graph_and_eager_modes @test.mock.patch.object(checkpoint_management, "time") def testSaveRestoreState(self, mock_time): directory = self.get_temp_dir() |