diff options
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r-- | tensorflow/python/training/checkpointable/util_test.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 98a42b1c20..a0a87b6b79 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -522,7 +522,6 @@ class CheckpointingTests(test.TestCase): # Does create garbage when executing eagerly due to ops.Graph() creation. num_training_steps = 10 checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") for training_continuation in range(3): with ops.Graph().as_default(), self.test_session( graph=ops.get_default_graph()), test_util.device(use_gpu=True): @@ -531,9 +530,9 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = checkpoint_management.latest_checkpoint( - checkpoint_directory) - status = root.restore(save_path=checkpoint_path) + manager = checkpoint_management.CheckpointManager( + root, checkpoint_directory, max_to_keep=1) + status = root.restore(save_path=manager.latest_checkpoint) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( optimizer.minimize, @@ -544,7 +543,7 @@ class CheckpointingTests(test.TestCase): status.initialize_or_restore() for _ in range(num_training_steps): train_fn() - root.save(file_prefix=checkpoint_prefix) + manager.save() self.assertEqual((training_continuation + 1) * num_training_steps, self.evaluate(root.global_step)) self.assertEqual(training_continuation + 1, |