diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-02 15:47:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 15:51:17 -0700 |
commit | 1bf206bc82f600886f1e19c9860f09f18984346b (patch) | |
tree | fbd6ee10df16e491142017e96120181b81a72ec5 /tensorflow/contrib/optimizer_v2 | |
parent | 6fbbad97e293cc39bde32495e92614c69a9a7896 (diff) |
Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher-level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and glomming it on to saver.py seems dirty.
PiperOrigin-RevId: 207179646
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 13 |
1 file changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 06ab58188a..28a531dfec 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training import saver as core_saver
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpointable import tracking
@@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
     root = util.Checkpoint(
         optimizer=optimizer, model=model,
         optimizer_step=training_util.get_or_create_global_step())
-    root.restore(core_saver.latest_checkpoint(checkpoint_directory))
+    root.restore(checkpoint_management.latest_checkpoint(
+        checkpoint_directory))
     for _ in range(num_training_steps):
       # TODO(allenl): Use a Dataset and serialize/checkpoint it.
       input_value = constant_op.constant([[3.]])
@@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
       train_op = optimizer.minimize(
           model(input_value), global_step=root.global_step)
-      checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+      checkpoint_path = checkpoint_management.latest_checkpoint(
+          checkpoint_directory)
       with self.test_session(graph=ops.get_default_graph()) as session:
         status = root.restore(save_path=checkpoint_path)
         status.initialize_or_restore(session=session)
@@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
     root = util.Checkpoint(
         optimizer=optimizer, model=model,
         global_step=training_util.get_or_create_global_step())
-    checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+    checkpoint_path = checkpoint_management.latest_checkpoint(
+        checkpoint_directory)
     status = root.restore(save_path=checkpoint_path)
     input_value = constant_op.constant([[3.]])
     train_fn = functools.partial(
@@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
     root = util.Checkpoint(
         optimizer=optimizer, model=model,
         global_step=training_util.get_or_create_global_step())
-    checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+    checkpoint_path = checkpoint_management.latest_checkpoint(
+        checkpoint_directory)
     status = root.restore(save_path=checkpoint_path)

     def train_fn():
       @function.defun