aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-02 15:47:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 15:51:17 -0700
commit1bf206bc82f600886f1e19c9860f09f18984346b (patch)
treefbd6ee10df16e491142017e96120181b81a72ec5 /tensorflow/contrib/optimizer_v2
parent6fbbad97e293cc39bde32495e92614c69a9a7896 (diff)
Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and globbing it on to saver.py seems dirty. PiperOrigin-RevId: 207179646
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 06ab58188a..28a531dfec 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import tracking
@@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
- root.restore(core_saver.latest_checkpoint(checkpoint_directory))
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun