diff options
Diffstat (limited to 'tensorflow/python/training/checkpoint_management_test.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_management_test.py | 201 |
1 files changed, 201 insertions, 0 deletions
# Imports visible in this hunk's context (post-change state).
# NOTE(review): the hunk begins at file line 26, so earlier imports are not
# visible here; `os` (used below) is presumably imported above -- confirm.
import tempfile

from google.protobuf import text_format

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import util

# LatestCheckpointWithRelativePaths and SaverUtilsTest appear in this view only
# as truncated hunk context; they are unchanged and not reproduced here.


class CheckpointManagerTest(test.TestCase):
  """Tests for `checkpoint_management.CheckpointManager`."""

  def _assert_exists(self, path):
    """Asserts that `checkpoint_exists` reports a checkpoint at `path`."""
    self.assertTrue(checkpoint_management.checkpoint_exists(path))

  def _assert_missing(self, path):
    """Asserts that `checkpoint_exists` reports no checkpoint at `path`."""
    self.assertFalse(checkpoint_management.checkpoint_exists(path))

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    """Checks that with max_to_keep=3 a fourth save drops the first one."""
    root = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        root, self.get_temp_dir(), max_to_keep=3)
    oldest, second, third, newest = [manager.save() for _ in range(4)]
    self._assert_exists(newest)
    self._assert_exists(third)
    self._assert_exists(second)
    self._assert_missing(oldest)

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    """Checks bookkeeping across three managers sharing one directory.

    The `time` module seen by `checkpoint_management` is mocked, so every
    timestamp asserted below is one this test set explicitly.
    """
    directory = self.get_temp_dir()
    clock = mock_time.time

    def read_state():
      return checkpoint_management.get_checkpoint_state(directory)

    # Expected checkpoint prefixes, in save order.
    ckpt_1 = os.path.join(directory, "ckpt-1")
    ckpt_2 = os.path.join(directory, "ckpt-2")
    ckpt_3 = os.path.join(directory, "ckpt-3")
    ckpt_4 = os.path.join(directory, "ckpt-4")
    ckpt_5 = os.path.join(directory, "ckpt-5")
    ckpt_6 = os.path.join(directory, "ckpt-6")

    # Mocked clock readings used for the first five saves.
    time_1 = 10000.
    time_2 = time_1 + 3610.
    time_3 = time_2 + 3600. * 0.2
    time_4 = time_3 + 3600. * 0.5
    time_5 = time_4 + 3600. * 0.5

    # The clock reads 3. while the first manager is constructed.
    clock.return_value = 3.
    root = util.Checkpoint()
    manager_a = checkpoint_management.CheckpointManager(
        root, directory, max_to_keep=2)

    clock.return_value = time_1
    manager_a.save()
    state = read_state()
    self.assertEqual([time_1], state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)

    clock.return_value = time_2
    manager_a.save()
    state = read_state()
    self.assertEqual([time_1, time_2],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual(3., state.last_preserved_timestamp)
    self.assertEqual([ckpt_1, ckpt_2], manager_a.checkpoints)
    self.assertEqual(ckpt_2, manager_a.latest_checkpoint)
    del manager_a

    # A second manager on the same directory, now with periodic preservation.
    manager_b = checkpoint_management.CheckpointManager(
        root, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([ckpt_1, ckpt_2], manager_b.checkpoints)
    self.assertEqual(ckpt_2, manager_b.latest_checkpoint)

    clock.return_value = time_3
    manager_b.save()
    self._assert_exists(ckpt_1)
    self._assert_exists(ckpt_2)
    self.assertEqual([ckpt_2, ckpt_3], manager_b.checkpoints)
    state = read_state()
    self.assertEqual(time_1, state.last_preserved_timestamp)

    clock.return_value = time_4
    manager_b.save()
    self._assert_exists(ckpt_1)
    self._assert_missing(ckpt_2)
    self.assertEqual([ckpt_3, ckpt_4], manager_b.checkpoints)

    clock.return_value = time_5
    manager_b.save()
    self.assertEqual([ckpt_4, ckpt_5], manager_b.checkpoints)
    state = read_state()
    self.assertEqual(time_1, state.last_preserved_timestamp)
    del manager_b

    # A third manager picks up where the second one stopped.
    manager_c = checkpoint_management.CheckpointManager(
        root, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(ckpt_5, manager_c.latest_checkpoint)
    clock.return_value += 10.
    manager_c.save()
    state = read_state()
    self.assertEqual(time_4, state.last_preserved_timestamp)
    self._assert_exists(ckpt_1)
    self._assert_exists(ckpt_4)
    self._assert_exists(ckpt_5)
    self._assert_exists(ckpt_6)
    self._assert_missing(ckpt_2)
    self._assert_missing(ckpt_3)
    self.assertEqual([ckpt_5, ckpt_6], manager_c.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    """Checks a manager started on a directory of unmanaged checkpoints.

    Two checkpoints are written directly through `Checkpoint.save`; the
    assertions verify they still exist after the manager has rotated its own.
    """
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "unusual_prefix")
    unmanaged = util.Checkpoint()
    unmanaged_first = unmanaged.save(prefix)
    unmanaged_second = unmanaged.save(prefix)
    del unmanaged

    root = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        root, directory, max_to_keep=2)
    root.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(root.save_counter))

    managed_first = manager.save()
    self.assertEqual([managed_first], manager.checkpoints)
    managed_second = manager.save()
    self.assertEqual([managed_first, managed_second],
                     manager.checkpoints)
    managed_third = manager.save()
    self.assertEqual([managed_second, managed_third],
                     manager.checkpoints)

    self._assert_exists(unmanaged_first)
    self._assert_exists(unmanaged_second)
    self._assert_missing(managed_first)
    self._assert_exists(managed_second)
    self._assert_exists(managed_third)

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    """Checks manager behavior when the mocked clock is set backwards."""
    directory = self.get_temp_dir()
    clock = mock_time.time
    clock.return_value = 10000.
    root = util.Checkpoint()
    manager_a = checkpoint_management.CheckpointManager(
        root, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    path_1 = manager_a.save()
    clock.return_value += 3600.
    path_2 = manager_a.save()
    clock.return_value += 3600.
    path_3 = manager_a.save()
    self._assert_missing(path_1)
    self._assert_exists(path_2)
    self._assert_exists(path_3)
    self.assertEqual([path_3], manager_a.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)

    # Move the clock to a point before everything saved so far.
    clock.return_value = 5000.
    del manager_a
    with test.mock.patch.object(logging, "warning") as mock_log:
      manager_b = checkpoint_management.CheckpointManager(
          root, directory, max_to_keep=1)
      self.assertRegexpMatches(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self._assert_exists(path_2)
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(path_3, manager_b.latest_checkpoint)
    self.assertEqual([], manager_b.checkpoints)

    clock.return_value += 10.
    path_4 = manager_b.save()
    self._assert_exists(path_2)
    self._assert_exists(path_3)
    self.assertEqual(path_4, manager_b.latest_checkpoint)
    self.assertEqual([path_4], manager_b.checkpoints)

    clock.return_value += 10.
    path_5 = manager_b.save()
    self._assert_exists(path_2)
    self._assert_exists(path_3)
    self.assertEqual([path_5], manager_b.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.],
                     state.all_model_checkpoint_timestamps)

  @test_util.run_in_graph_and_eager_modes
  def testCustomNumbering(self):
    """Checks `save(checkpoint_number=...)` with a variable and with an int."""
    directory = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    root = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        root, directory, max_to_keep=2)
    self.evaluate(step.initializer)

    # `step` starts at 0 and grows by 2 after each save, so the five saves
    # are expected to be numbered 0, 2, 4, 6 and 8.
    for expected_number in range(0, 10, 2):
      path = manager.save(checkpoint_number=step)
      expected_suffix = "-%d" % (expected_number,)
      if not path.endswith(expected_suffix):
        self.fail("%s should have suffix %s" % (path, expected_suffix))
      self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(root.save_counter))

    # A plain Python integer is accepted as the checkpoint number too.
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(
        last_path, checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # Only the two most recent checkpoints remain listed in the state.
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)


if __name__ == "__main__":
  test.main()