aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpointable/util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpointable/util_test.py')
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 98a42b1c20..a0a87b6b79 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -522,7 +522,6 @@ class CheckpointingTests(test.TestCase):
# Does create garbage when executing eagerly due to ops.Graph() creation.
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
for training_continuation in range(3):
with ops.Graph().as_default(), self.test_session(
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
@@ -531,9 +530,9 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = checkpoint_management.latest_checkpoint(
- checkpoint_directory)
- status = root.restore(save_path=checkpoint_path)
+ manager = checkpoint_management.CheckpointManager(
+ root, checkpoint_directory, max_to_keep=1)
+ status = root.restore(save_path=manager.latest_checkpoint)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
optimizer.minimize,
@@ -544,7 +543,7 @@ class CheckpointingTests(test.TestCase):
status.initialize_or_restore()
for _ in range(num_training_steps):
train_fn()
- root.save(file_prefix=checkpoint_prefix)
+ manager.save()
self.assertEqual((training_continuation + 1) * num_training_steps,
self.evaluate(root.global_step))
self.assertEqual(training_continuation + 1,