about summary refs log tree commit diff homepage
path: root/tensorflow/python/training/checkpoint_management_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpoint_management_test.py')
-rw-r--r-- tensorflow/python/training/checkpoint_management_test.py 201
1 files changed, 201 insertions, 0 deletions
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 4b31d0c613..1e2827d0a4 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,14 +26,18 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.training.checkpointable import util
class LatestCheckpointWithRelativePaths(test.TestCase):
@@ -312,5 +316,202 @@ class SaverUtilsTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
+class CheckpointManagerTest(test.TestCase):  # Tests for checkpoint_management.CheckpointManager.
+ # Each test saves checkpoints into self.get_temp_dir() and checks which ones remain.
+ @test_util.run_in_graph_and_eager_modes
+ def testDeletion(self):  # With max_to_keep=3, a fourth save removes the first checkpoint.
+ checkpoint = util.Checkpoint()
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, self.get_temp_dir(), max_to_keep=3)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ fourth_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))  # The oldest one is gone.
+ # `time` is patched on checkpoint_management so the test controls every timestamp.
+ @test_util.run_in_graph_and_eager_modes
+ @test.mock.patch.object(checkpoint_management, "time")
+ def testSaveRestoreState(self, mock_time):  # Managers re-created on one directory pick up the saved state.
+ directory = self.get_temp_dir()
+ mock_time.time.return_value = 3.  # Mocked clock value while the first manager is constructed.
+ checkpoint = util.Checkpoint()
+ first_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ first_time = 10000.
+ first_name = os.path.join(directory, "ckpt-1")
+ mock_time.time.return_value = first_time
+ first_manager.save()
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
+ self.assertEqual(3., state.last_preserved_timestamp)  # Still the construction-time clock value.
+ second_time = first_time + 3610.  # 3610 s (just over an hour) after the first save.
+ second_name = os.path.join(directory, "ckpt-2")
+ mock_time.time.return_value = second_time
+ first_manager.save()
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual([first_time, second_time],
+ state.all_model_checkpoint_timestamps)
+ self.assertEqual(3., state.last_preserved_timestamp)
+ self.assertEqual([first_name, second_name], first_manager.checkpoints)
+ self.assertEqual(second_name, first_manager.latest_checkpoint)
+ del first_manager
+ # A new manager on the same directory, now also given keep_checkpoint_every_n_hours=1.5.
+ second_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory,
+ max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
+ self.assertEqual([first_name, second_name], second_manager.checkpoints)
+ self.assertEqual(second_name, second_manager.latest_checkpoint)
+ third_name = os.path.join(directory, "ckpt-3")
+ third_time = second_time + 3600. * 0.2  # 0.2 h after the second save.
+ mock_time.time.return_value = third_time
+ second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
+ self.assertEqual([second_name, third_name],
+ second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(first_time, state.last_preserved_timestamp)  # ckpt-1 left the managed list but still exists: it was preserved.
+ fourth_time = third_time + 3600. * 0.5  # 0.5 h after the third save.
+ mock_time.time.return_value = fourth_time
+ fourth_name = os.path.join(directory, "ckpt-4")
+ second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_name))  # Deleted; presumably because it is under 1.5 h after the preserved ckpt-1.
+ self.assertEqual([third_name, fourth_name],
+ second_manager.checkpoints)
+ fifth_time = fourth_time + 3600. * 0.5  # 0.5 h after the fourth save.
+ mock_time.time.return_value = fifth_time
+ fifth_name = os.path.join(directory, "ckpt-5")
+ second_manager.save()
+ self.assertEqual([fourth_name, fifth_name],
+ second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(first_time, state.last_preserved_timestamp)
+ del second_manager
+ third_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory,
+ max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
+ self.assertEqual(fifth_name, third_manager.latest_checkpoint)
+ mock_time.time.return_value += 10.  # The sixth save happens 10 s after the fifth.
+ third_manager.save()
+ sixth_name = os.path.join(directory, "ckpt-6")
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(fourth_time, state.last_preserved_timestamp)  # ckpt-4 (about 1.7 h after ckpt-1) is now the preserved one.
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
+ self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
+ self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
+ self.assertEqual([fifth_name, sixth_name],
+ third_manager.checkpoints)
+ # Checkpoints written directly through util.Checkpoint.save are not deleted by the manager.
+ @test_util.run_in_graph_and_eager_modes
+ def testContinueFromUnmanaged(self):  # A manager can start in a directory that already holds unmanaged checkpoints.
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "unusual_prefix")  # Differs from the "ckpt-N" names the manager produces in testSaveRestoreState.
+ checkpoint = util.Checkpoint()
+ first_path = checkpoint.save(prefix)
+ second_path = checkpoint.save(prefix)
+ del checkpoint
+ checkpoint = util.Checkpoint()
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ checkpoint.restore(manager.latest_checkpoint).run_restore_ops()  # Restores the newest unmanaged save (save_counter is 2 below).
+ self.assertEqual(2, self.evaluate(checkpoint.save_counter))
+ third_path = manager.save()
+ self.assertEqual([third_path], manager.checkpoints)  # The two unmanaged saves are not listed.
+ fourth_path = manager.save()
+ self.assertEqual([third_path, fourth_path],
+ manager.checkpoints)
+ fifth_path = manager.save()
+ self.assertEqual([fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(third_path))  # Only manager-saved checkpoints fall under max_to_keep=2.
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+ # `time` is patched so the test can move the clock backwards between managers.
+ @test_util.run_in_graph_and_eager_modes
+ @test.mock.patch.object(checkpoint_management, "time")
+ def testClockReset(self, mock_time):  # A clock that jumps backwards must not get existing checkpoints deleted.
+ directory = self.get_temp_dir()
+ mock_time.time.return_value = 10000.
+ checkpoint = util.Checkpoint()
+ first_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
+ first_path = first_manager.save()
+ mock_time.time.return_value += 3600.  # Second save at 13600.
+ second_path = first_manager.save()
+ mock_time.time.return_value += 3600.  # Third save at 17200.
+ third_path = first_manager.save()
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual([third_path], first_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(13600., state.last_preserved_timestamp)  # The time of the second save, which still exists.
+ # Set the clock back in time
+ mock_time.time.return_value = 5000.
+ del first_manager
+ with test.mock.patch.object(logging, "warning") as mock_log:  # Capture logging.warning calls for the check below.
+ second_manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=1)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ "behind the last preserved checkpoint timestamp")
+ # We should err on the side of keeping checkpoints around when we're not
+ # sure whether they were preserved or not due to clock funkiness.
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ # We know about the existing checkpoints, but they'll never be deleted and
+ # so won't go in the CheckpointState proto on save.
+ self.assertEqual(third_path, second_manager.latest_checkpoint)
+ self.assertEqual([], second_manager.checkpoints)
+ mock_time.time.return_value += 10.  # Fourth save at 5010.
+ fourth_path = second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual(fourth_path, second_manager.latest_checkpoint)
+ self.assertEqual([fourth_path], second_manager.checkpoints)
+ mock_time.time.return_value += 10.  # Fifth save at 5020.
+ fifth_path = second_manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertEqual([fifth_path], second_manager.checkpoints)
+ state = checkpoint_management.get_checkpoint_state(directory)
+ self.assertEqual(5000., state.last_preserved_timestamp)  # The rewound clock value from when second_manager was built.
+ self.assertEqual([5020.],
+ state.all_model_checkpoint_timestamps)
+ # checkpoint_number lets the caller choose the numeric suffix of the saved path.
+ @test_util.run_in_graph_and_eager_modes
+ def testCustomNumbering(self):  # save() accepts a variable or a plain int as checkpoint_number.
+ directory = self.get_temp_dir()
+ step = variables.Variable(0, dtype=dtypes.int64)
+ checkpoint = util.Checkpoint(step=step)
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ self.evaluate(step.initializer)
+ for i in range(5):
+ path = manager.save(checkpoint_number=step)  # The variable itself is passed as the number.
+ expected_suffix = "-%d" % (2 * i,)  # step is 0, 2, 4, 6, 8 across the iterations.
+ if not path.endswith(expected_suffix):
+ self.fail("%s should have suffix %s" % (path, expected_suffix))
+ self.evaluate(step.assign_add(2))
+ self.assertEqual(5, self.evaluate(checkpoint.save_counter))  # save_counter counts saves, independent of the custom numbers.
+ # Test regular integers
+ last_path = manager.save(checkpoint_number=32)
+ self.assertIn("-32", last_path)
+ self.assertEqual(last_path, manager.latest_checkpoint)
+ self.assertEqual(
+ last_path, checkpoint_management.latest_checkpoint(directory))
+ state = checkpoint_management.get_checkpoint_state(directory)
+ # Only the two most recent checkpoints are kept (max_to_keep=2)
+ self.assertEqual([path, last_path], state.all_model_checkpoint_paths)  # path is the final loop save (suffix "-8").
+
+
if __name__ == "__main__":  # Call test.main() only when this file is executed as a script.
test.main()