aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py3
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py3
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py5
-rw-r--r--tensorflow/contrib/eager/python/examples/spinn/spinn_test.py4
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py14
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py5
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py13
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py4
-rw-r--r--tensorflow/contrib/training/python/training/training_test.py3
-rw-r--r--tensorflow/python/BUILD70
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_test.py4
-rw-r--r--tensorflow/python/estimator/estimator.py13
-rw-r--r--tensorflow/python/estimator/estimator_test.py4
-rw-r--r--tensorflow/python/estimator/keras.py3
-rw-r--r--tensorflow/python/tools/freeze_graph.py3
-rw-r--r--tensorflow/python/training/checkpoint_management.py406
-rw-r--r--tensorflow/python/training/checkpoint_management_test.py316
-rw-r--r--tensorflow/python/training/checkpoint_utils.py3
-rw-r--r--tensorflow/python/training/checkpointable/BUILD4
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py16
-rw-r--r--tensorflow/python/training/monitored_session_test.py5
-rw-r--r--tensorflow/python/training/saver.py395
-rw-r--r--tensorflow/python/training/saver_test.py460
-rw-r--r--tensorflow/python/training/session_manager.py6
-rw-r--r--tensorflow/python/training/session_manager_test.py5
-rw-r--r--tensorflow/python/training/supervisor_test.py3
-rw-r--r--tensorflow/python/training/training.py12
35 files changed, 1011 insertions, 817 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..77148aceec 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
def _read_vars(self, model_dir):
"""Returns (global_step, latest_feature)."""
with ops.Graph().as_default() as g:
- ckpt_path = saver_lib.latest_checkpoint(model_dir)
+ ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
meta_filename = ckpt_path + '.meta'
saver_lib.import_meta_graph(meta_filename)
saver = saver_lib.Saver()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 393f08850b..3ed4dfb729 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
@@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
- return saver_lib.latest_checkpoint(self.get_temp_dir())
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 0d71be6601..d2c1d0d362 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
@@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
# Check if there is an existing checkpoint. If so, restore from it.
# pylint: disable=protected-access
- latest_checkpoint_path = saver_lib.latest_checkpoint(
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
self._checkpoint_saver_hook._checkpoint_dir,
latest_filename=self._latest_filename)
if latest_checkpoint_path:
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index 2917eaac97..a753d77580 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -314,7 +314,8 @@ class IteratorTest(test.TestCase):
for i in range(5):
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
+ checkpoint.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for j in range(2):
self.assertEqual(i * 2 + j, iterator.get_next().numpy())
checkpoint.save(file_prefix=checkpoint_prefix)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 8ac553e0ae..d18a097063 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
# pylint: enable=g-bad-import-order
@@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
# 5. Verify that checkpoints exist and contains all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
object_graph = checkpointable_utils.object_metadata(
- saver.latest_checkpoint(config.logdir))
+ checkpoint_management.latest_checkpoint(config.logdir))
ckpt_variable_names = set()
for node in object_graph.nodes:
for attribute in node.attributes:
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
index 9e356dd965..e7184a01fb 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
@@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as train
__all__ = [
@@ -40,7 +40,7 @@ __all__ = [
def _get_checkpoint_filename(filepattern):
"""Returns checkpoint filename given directory or specific filepattern."""
if gfile.IsDirectory(filepattern):
- return saver.latest_checkpoint(filepattern)
+ return checkpoint_management.latest_checkpoint(filepattern)
return filepattern
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 7a026a15e4..c1de42782e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index f8a3709ee5..08e907a608 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
# Load and cache the path of the most recent checkpoint to avoid duplicate
# searches on GCS.
logging.info("Checking for checkpoint in %s", self._model_dir)
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.warning("Skipping evaluation and export since model has not been "
@@ -516,7 +516,8 @@ class Experiment(object):
start = time.time()
error_msg = None
- latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
@@ -778,7 +779,8 @@ class Experiment(object):
saving_listeners=self._saving_listeners)
logging.info("Evaluating model now.")
- latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_checkpoint = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
eval_result = self._call_evaluate(
input_fn=self._eval_input_fn,
steps=self._eval_steps,
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 0d039d593b..df156da3f4 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 77f7c73d54..3d691d4340 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
-from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
@@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
return False
self._last_checkpoint_check_time = current_time
# Check that we are not running evaluation on the same checkpoint.
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.debug("Skipping evaluation since model has not been saved yet "
"at step %d.", step)
@@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.info("Skipping export at the end since model has not been saved "
"yet.")
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 5c34d0ddb0..ff1da32c21 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
from tensorflow.python.training import training_util
@@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase):
mock_latest_checkpoint.assert_called_with(model_dir)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_early_stopping_rounds(self,
mock_latest_checkpoint,
mock_estimator_class):
@@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
monitor.epoch_end(epoch=0)
monitor.end()
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
model_dir = 'model/dir'
@@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
expected_best_metrics={'loss': 42.0, 'auc': 0.5})
monitor.post_step(step=step, session=None)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_fail_with_core_estimator_and_metrics(
self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 3eacac7a3d..0144b93814 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
@@ -298,7 +299,8 @@ def _export_estimator(estimator,
# If checkpoint_path is specified, use the specified checkpoint path.
checkpoint_path = (checkpoint_path or
- tf_saver.latest_checkpoint(estimator._model_dir))
+ checkpoint_management.latest_checkpoint(
+ estimator._model_dir))
with ops.Graph().as_default() as g:
training_util.create_global_step(g)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index f8106d1e4a..66af6833da 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@@ -714,7 +714,8 @@ def make_best_model_export_strategy(
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 06ab58188a..28a531dfec 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import tracking
@@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
- root.restore(core_saver.latest_checkpoint(checkpoint_directory))
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
index af3b2ad1b5..c2166594e5 100644
--- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -22,8 +22,8 @@ from __future__ import print_function
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.predictor import predictor
from tensorflow.python.framework import ops
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
class ContribEstimatorPredictor(predictor.Predictor):
@@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
# pylint: disable=protected-access
model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
# pylint: enable=protected-access
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
config=config,
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index f7fd66d33f..01bac891da 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir,
logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
stop_time = time.time() + timeout if timeout is not None else None
while True:
- checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None or checkpoint_path == last_checkpoint:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None
diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py
index 4877c010fa..94cf7788b2 100644
--- a/tensorflow/contrib/training/python/training/training_test.py
+++ b/tensorflow/contrib/training/python/training/training_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@@ -421,7 +422,7 @@ class TrainTest(test.TestCase):
train_op = self.create_train_op()
model_variables = variables_lib2.global_variables()
- model_path = saver_lib.latest_checkpoint(logdir1)
+ model_path = checkpoint_management.latest_checkpoint(logdir1)
assign_fn = variables_lib.assign_from_checkpoint_fn(
model_path, model_variables)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2b8110a999..7cf8ddb1d9 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3216,6 +3216,7 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
"training/training_util.py",
@@ -3223,8 +3224,10 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
+ "saver",
":array_ops",
":array_ops_gen",
+ ":checkpoint_management",
":checkpoint_ops_gen",
":client",
":control_flow_ops",
@@ -3236,25 +3239,20 @@ py_library(
":framework_ops",
":gradients",
":init_ops",
- ":distribute",
":io_ops",
- ":io_ops_gen",
":layers_base",
- ":lib",
":lookup_ops",
":math_ops",
":platform",
- ":protos_all_py",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
":resources",
- "saver",
- ":saveable_object",
":sdca_ops",
+ ":session",
":sparse_ops",
+ ":sparse_tensor",
":state_ops",
- ":string_ops",
":summary",
":training_ops_gen",
":training_util",
@@ -3264,6 +3262,7 @@ py_library(
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@@ -3281,11 +3280,25 @@ py_library(
)
py_library(
+ name = "checkpoint_management",
+ srcs = ["training/checkpoint_management.py"],
+ deps = [
+ ":errors",
+ ":lib",
+ ":platform",
+ ":protos_all_py",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":checkpoint_management",
":constant_op",
":control_flow_ops",
":device",
@@ -3294,9 +3307,7 @@ py_library(
":framework_ops",
":io_ops",
":io_ops_gen",
- ":lib",
":platform",
- ":protos_all_py",
":pywrap_tensorflow",
":resource_variable_ops",
":saveable_object",
@@ -4423,6 +4434,42 @@ cuda_py_test(
tags = ["multi_gpu"],
)
+cuda_py_test(
+ name = "checkpoint_management_test",
+ size = "small",
+ srcs = [
+ "training/checkpoint_management_test.py",
+ ],
+ additional_deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":data_flow_ops",
+ ":errors",
+ ":gradients",
+ ":math_ops",
+ ":nn_grad",
+ ":nn_ops",
+ ":saver_test_utils",
+ ":partitioned_variables",
+ ":platform",
+ ":platform_test",
+ ":pywrap_tensorflow",
+ ":random_ops",
+ ":resource_variable_ops",
+ ":sparse_ops",
+ ":summary",
+ ":training",
+ ":util",
+ ":variable_scope",
+ ":variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
py_test(
name = "saver_large_variable_test",
size = "medium",
@@ -4489,6 +4536,7 @@ tf_py_test(
srcs = ["training/supervisor_test.py"],
additional_deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":errors",
":framework",
@@ -4496,6 +4544,7 @@ tf_py_test(
":io_ops",
":parsing_ops",
":platform",
+ ":saver",
":summary",
":training",
":variables",
@@ -4609,10 +4658,13 @@ py_test(
tags = ["notsan"], # b/67945581
deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":control_flow_ops",
":errors",
":framework_for_generated_wrappers",
+ ":resource_variable_ops",
+ ":saver",
":session",
":state_ops",
":summary",
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index dd39262f9b..352424514e 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -877,7 +877,7 @@ class IteratorCheckpointingTest(test.TestCase):
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
with self.test_session() as sess:
- checkpoint.restore(saver.latest_checkpoint(
+ checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
self.assertEqual(i * 2 + j, sess.run(get_next))
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a4e735a092..43deb8bc6c 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -53,6 +53,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
@@ -268,7 +269,7 @@ class Estimator(object):
found.
"""
with context.graph_mode():
- return saver.latest_checkpoint(self.model_dir)
+ return checkpoint_management.latest_checkpoint(self.model_dir)
def train(self,
input_fn,
@@ -417,7 +418,7 @@ class Estimator(object):
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to evaluate.'.format(self._model_dir))
@@ -504,7 +505,8 @@ class Estimator(object):
hooks = _check_hooks_type(hooks)
# Check that model has been trained.
if not checkpoint_path:
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._model_dir)
if not checkpoint_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to predict.'.format(self._model_dir))
@@ -769,7 +771,8 @@ class Estimator(object):
with context.graph_mode():
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._model_dir)
if not checkpoint_path:
raise ValueError("Couldn't find trained model at %s." % self._model_dir)
@@ -1626,7 +1629,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
def _check_checkpoint_available(model_dir):
- latest_path = saver.latest_checkpoint(model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(model_dir)
if not latest_path:
raise ValueError(
'Could not find trained model in model_dir: {}.'.format(model_dir))
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 68fc5bcadf..e8552092e0 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -69,6 +69,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@@ -1548,7 +1549,8 @@ class EstimatorPredictTest(test.TestCase):
next(
est.predict(
dummy_input_fn,
- checkpoint_path=saver.latest_checkpoint('fakedir')))
+ checkpoint_path=
+ checkpoint_management.latest_checkpoint('fakedir')))
def test_tensor_predictions(self):
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 079560c495..c63deb8f4d 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -442,7 +443,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
# save checkpoint into subdirectory to allow warm start
keras_model_dir = os.path.join(config.model_dir, 'keras')
# Load weights and save to checkpoint if there is no checkpoint
- latest_path = saver_lib.latest_checkpoint(keras_model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
if not latest_path:
keras_weights = None
if _any_weight_initialized(keras_model):
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 4349699a94..130fe70beb 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -78,7 +79,7 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
- not saver_lib.checkpoint_exists(input_checkpoint)):
+ not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
new file mode 100644
index 0000000000..aaddc015ed
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -0,0 +1,406 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=invalid-name
+"""Save and restore variables."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import re
+
+from google.protobuf import text_format
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util.tf_export import tf_export
+
+
+def _GetCheckpointFilename(save_dir, latest_filename):
+ """Returns a filename for storing the CheckpointState.
+
+ Args:
+ save_dir: The directory for saving and restoring checkpoints.
+ latest_filename: Name of the file in 'save_dir' that is used
+ to store the CheckpointState.
+
+ Returns:
+ The path of the file that contains the CheckpointState proto.
+ """
+ if latest_filename is None:
+ latest_filename = "checkpoint"
+ return os.path.join(save_dir, latest_filename)
+
+
+@tf_export("train.generate_checkpoint_state_proto")
+def generate_checkpoint_state_proto(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None):
+ """Generates a checkpoint state proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+
+ Returns:
+ CheckpointState proto with model_checkpoint_path and
+ all_model_checkpoint_paths updated to either absolute paths or
+ relative paths to the current save_dir.
+ """
+ if all_model_checkpoint_paths is None:
+ all_model_checkpoint_paths = []
+
+ if (not all_model_checkpoint_paths or
+ all_model_checkpoint_paths[-1] != model_checkpoint_path):
+ logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
+ model_checkpoint_path)
+ all_model_checkpoint_paths.append(model_checkpoint_path)
+
+ # Relative paths need to be rewritten to be relative to the "save_dir"
+ # if model_checkpoint_path already contains "save_dir".
+ if not os.path.isabs(save_dir):
+ if not os.path.isabs(model_checkpoint_path):
+ model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
+ for i in range(len(all_model_checkpoint_paths)):
+ p = all_model_checkpoint_paths[i]
+ if not os.path.isabs(p):
+ all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
+
+ coord_checkpoint_proto = CheckpointState(
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+
+ return coord_checkpoint_proto
+
+
+@tf_export("train.update_checkpoint_state")
+def update_checkpoint_state(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+ latest_filename: Optional name of the checkpoint file. Default to
+ 'checkpoint'.
+
+ Raises:
+ RuntimeError: If any of the model checkpoint paths conflict with the file
+ containing CheckpointSate.
+ """
+ update_checkpoint_state_internal(
+ save_dir=save_dir,
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ latest_filename=latest_filename,
+ save_relative_paths=False)
+
+
+def update_checkpoint_state_internal(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None,
+ save_relative_paths=False):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+ latest_filename: Optional name of the checkpoint file. Default to
+ 'checkpoint'.
+ save_relative_paths: If `True`, will write relative paths to the checkpoint
+ state file.
+
+ Raises:
+ RuntimeError: If any of the model checkpoint paths conflict with the file
+ containing CheckpointSate.
+ """
+ # Writes the "checkpoint" file for the coordinator for later restoration.
+ coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
+ if save_relative_paths:
+ if os.path.isabs(model_checkpoint_path):
+ rel_model_checkpoint_path = os.path.relpath(
+ model_checkpoint_path, save_dir)
+ else:
+ rel_model_checkpoint_path = model_checkpoint_path
+ rel_all_model_checkpoint_paths = []
+ for p in all_model_checkpoint_paths:
+ if os.path.isabs(p):
+ rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
+ else:
+ rel_all_model_checkpoint_paths.append(p)
+ ckpt = generate_checkpoint_state_proto(
+ save_dir,
+ rel_model_checkpoint_path,
+ all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
+ else:
+ ckpt = generate_checkpoint_state_proto(
+ save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+
+ if coord_checkpoint_filename == ckpt.model_checkpoint_path:
+ raise RuntimeError("Save path '%s' conflicts with path used for "
+ "checkpoint state. Please use a different save path." %
+ model_checkpoint_path)
+
+ # Preventing potential read/write race condition by *atomically* writing to a
+ # file.
+ file_io.atomic_write_string_to_file(coord_checkpoint_filename,
+ text_format.MessageToString(ckpt))
+
+
+@tf_export("train.get_checkpoint_state")
+def get_checkpoint_state(checkpoint_dir, latest_filename=None):
+ """Returns CheckpointState proto from the "checkpoint" file.
+
+ If the "checkpoint" file contains a valid CheckpointState
+ proto, returns it.
+
+ Args:
+ checkpoint_dir: The directory of checkpoints.
+ latest_filename: Optional name of the checkpoint file. Default to
+ 'checkpoint'.
+
+ Returns:
+ A CheckpointState if the state was available, None
+ otherwise.
+
+ Raises:
+ ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
+ """
+ ckpt = None
+ coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
+ latest_filename)
+ f = None
+ try:
+ # Check that the file exists before opening it to avoid
+ # many lines of errors from colossus in the logs.
+ if file_io.file_exists(coord_checkpoint_filename):
+ file_content = file_io.read_file_to_string(
+ coord_checkpoint_filename)
+ ckpt = CheckpointState()
+ text_format.Merge(file_content, ckpt)
+ if not ckpt.model_checkpoint_path:
+ raise ValueError("Invalid checkpoint state loaded from "
+ + checkpoint_dir)
+ # For relative model_checkpoint_path and all_model_checkpoint_paths,
+ # prepend checkpoint_dir.
+ if not os.path.isabs(ckpt.model_checkpoint_path):
+ ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
+ ckpt.model_checkpoint_path)
+ for i in range(len(ckpt.all_model_checkpoint_paths)):
+ p = ckpt.all_model_checkpoint_paths[i]
+ if not os.path.isabs(p):
+ ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
+ except errors.OpError as e:
+ # It's ok if the file cannot be read
+ logging.warning("%s: %s", type(e).__name__, e)
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ except text_format.ParseError as e:
+ logging.warning("%s: %s", type(e).__name__, e)
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ finally:
+ if f:
+ f.close()
+ return ckpt
+
+
+def _prefix_to_checkpoint_path(prefix, format_version):
+ """Returns the pathname of a checkpoint file, given the checkpoint prefix.
+
+ For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
+ returns the pathname to the index file.
+
+ Args:
+ prefix: a string, the prefix of a checkpoint.
+ format_version: the checkpoint format version that corresponds to the
+ prefix.
+ Returns:
+ The pathname of a checkpoint file, taking into account the checkpoint
+ format version.
+ """
+ if format_version == saver_pb2.SaverDef.V2:
+ return prefix + ".index" # The index file identifies a checkpoint.
+ return prefix # Just the data file.
+
+
+@tf_export("train.latest_checkpoint")
+def latest_checkpoint(checkpoint_dir, latest_filename=None):
+ """Finds the filename of latest saved checkpoint file.
+
+ Args:
+ checkpoint_dir: Directory where the variables were saved.
+ latest_filename: Optional name for the protocol buffer file that
+ contains the list of most recent checkpoint filenames.
+ See the corresponding argument to `Saver.save()`.
+
+ Returns:
+ The full path to the latest checkpoint or `None` if no checkpoint was found.
+ """
+ # Pick the latest checkpoint based on checkpoint state.
+ ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
+ if ckpt and ckpt.model_checkpoint_path:
+ # Look for either a V2 path or a V1 path, with priority for V2.
+ v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+ saver_pb2.SaverDef.V2)
+ v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+ saver_pb2.SaverDef.V1)
+ if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
+ v1_path):
+ return ckpt.model_checkpoint_path
+ else:
+ logging.error("Couldn't match files for checkpoint %s",
+ ckpt.model_checkpoint_path)
+ return None
+
+
+@tf_export("train.checkpoint_exists")
+def checkpoint_exists(checkpoint_prefix):
+ """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
+
+ This is the recommended way to check if a checkpoint exists, since it takes
+ into account the naming difference between V1 and V2 formats.
+
+ Args:
+ checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ Returns:
+ A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
+ """
+ pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
+ saver_pb2.SaverDef.V2)
+ if file_io.get_matching_files(pathname):
+ return True
+ elif file_io.get_matching_files(checkpoint_prefix):
+ return True
+ else:
+ return False
+
+
+@tf_export("train.get_checkpoint_mtimes")
+def get_checkpoint_mtimes(checkpoint_prefixes):
+ """Returns the mtimes (modification timestamps) of the checkpoints.
+
+ Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
+ exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
+ that priority.
+
+ This is the recommended way to get the mtimes, since it takes into account
+ the naming difference between V1 and V2 formats.
+
+ Args:
+ checkpoint_prefixes: a list of checkpoint paths, typically the results of
+ `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
+ sharded/non-sharded or V1/V2.
+ Returns:
+ A list of mtimes (in microseconds) of the found checkpoints.
+ """
+ mtimes = []
+
+ def match_maybe_append(pathname):
+ fnames = file_io.get_matching_files(pathname)
+ if fnames:
+ mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
+ return True
+ return False
+
+ for checkpoint_prefix in checkpoint_prefixes:
+ # Tries V2's metadata file first.
+ pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
+ saver_pb2.SaverDef.V2)
+ if match_maybe_append(pathname):
+ continue
+ # Otherwise, tries V1, where the prefix is the complete pathname.
+ match_maybe_append(checkpoint_prefix)
+
+ return mtimes
+
+
+@tf_export("train.remove_checkpoint")
+def remove_checkpoint(checkpoint_prefix,
+ checkpoint_format_version=saver_pb2.SaverDef.V2,
+ meta_graph_suffix="meta"):
+ """Removes a checkpoint given by `checkpoint_prefix`.
+
+ Args:
+ checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
+ of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
+ sharded/non-sharded or V1/V2.
+ checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
+ `SaverDef.V2`.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+ """
+ _delete_file_if_exists(
+ meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
+ if checkpoint_format_version == saver_pb2.SaverDef.V2:
+ # V2 has a metadata file and some data files.
+ _delete_file_if_exists(checkpoint_prefix + ".index")
+ _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
+ else:
+ # V1, Legacy. Exact match on the data file.
+ _delete_file_if_exists(checkpoint_prefix)
+
+
+def _delete_file_if_exists(filespec):
+ """Deletes files matching `filespec`."""
+ for pathname in file_io.get_matching_files(filespec):
+ file_io.delete_file(pathname)
+
+
+def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
+ """Returns the meta graph filename.
+
+ Args:
+ checkpoint_filename: Name of the checkpoint file.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+
+ Returns:
+ MetaGraph file name.
+ """
+ # If the checkpoint_filename is sharded, the checkpoint_filename could
+ # be of format model.ckpt-step#-?????-of-shard#. For example,
+ # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
+ basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
+ suffixed_filename = ".".join([basename, meta_graph_suffix])
+ return suffixed_filename
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
new file mode 100644
index 0000000000..4b31d0c613
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -0,0 +1,316 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.python.training.saver.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import os
+import shutil
+import tempfile
+
+from google.protobuf import text_format
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_module
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+
+
+class LatestCheckpointWithRelativePaths(test.TestCase):
+
+ @staticmethod
+ @contextlib.contextmanager
+ def tempWorkingDir(temppath):
+ cwd = os.getcwd()
+ os.chdir(temppath)
+ try:
+ yield
+ finally:
+ os.chdir(cwd)
+
+ @staticmethod
+ @contextlib.contextmanager
+ def tempDir():
+ tempdir = tempfile.mkdtemp()
+ try:
+ yield tempdir
+ finally:
+ shutil.rmtree(tempdir)
+
+ def testNameCollision(self):
+ # Make sure we have a clean directory to work in.
+ with self.tempDir() as tempdir:
+ # Jump to that directory until this test is done.
+ with self.tempWorkingDir(tempdir):
+ # Save training snapshots to a relative path.
+ traindir = "train/"
+ os.mkdir(traindir)
+ # Collides with the default name of the checkpoint state file.
+ filepath = os.path.join(traindir, "checkpoint")
+
+ with self.test_session() as sess:
+ unused_a = variables.Variable(0.0) # So that Saver saves something.
+ variables.global_variables_initializer().run()
+
+ # Should fail.
+ saver = saver_module.Saver(sharded=False)
+ with self.assertRaisesRegexp(ValueError, "collides with"):
+ saver.save(sess, filepath)
+
+ # Succeeds: the file will be named "checkpoint-<step>".
+ saver.save(sess, filepath, global_step=1)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
+ saver = saver_module.Saver(sharded=True)
+ saver.save(sess, filepath)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
+ saver = saver_module.Saver(sharded=True)
+ saver.save(sess, filepath, global_step=1)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ def testRelativePath(self):
+ # Make sure we have a clean directory to work in.
+ with self.tempDir() as tempdir:
+
+ # Jump to that directory until this test is done.
+ with self.tempWorkingDir(tempdir):
+
+ # Save training snapshots to a relative path.
+ traindir = "train/"
+ os.mkdir(traindir)
+
+ filename = "snapshot"
+ filepath = os.path.join(traindir, filename)
+
+ with self.test_session() as sess:
+ # Build a simple graph.
+ v0 = variables.Variable(0.0)
+ inc = v0.assign_add(1.0)
+
+ save = saver_module.Saver({"v0": v0})
+
+ # Record a short training history.
+ variables.global_variables_initializer().run()
+ save.save(sess, filepath, global_step=0)
+ inc.eval()
+ save.save(sess, filepath, global_step=1)
+ inc.eval()
+ save.save(sess, filepath, global_step=2)
+
+ with self.test_session() as sess:
+ # Build a new graph with different initialization.
+ v0 = variables.Variable(-1.0)
+
+ # Create a new saver.
+ save = saver_module.Saver({"v0": v0})
+ variables.global_variables_initializer().run()
+
+ # Get the most recent checkpoint name from the training history file.
+ name = checkpoint_management.latest_checkpoint(traindir)
+ self.assertIsNotNone(name)
+
+ # Restore "v0" from that checkpoint.
+ save.restore(sess, name)
+ self.assertEqual(v0.eval(), 2.0)
+
+
+class CheckpointStateTest(test.TestCase):
+
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
+ def testAbsPath(self):
+ save_dir = self._get_test_dir("abs_paths")
+ abs_path = os.path.join(save_dir, "model-0")
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ save_dir, abs_path)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path)
+ self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
+
+ def testRelPath(self):
+ train_dir = "train"
+ model = os.path.join(train_dir, "model-0")
+ # model_checkpoint_path should have no "train" directory part.
+ new_rel_path = "model-0"
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ train_dir, model)
+ self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
+
+ def testAllModelCheckpointPaths(self):
+ save_dir = self._get_test_dir("all_models_test")
+ abs_path = os.path.join(save_dir, "model-0")
+ for paths in [None, [], ["model-2"]]:
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ save_dir, abs_path, all_model_checkpoint_paths=paths)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path)
+ self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
+ self.assertEqual(
+ len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
+
+ def testUpdateCheckpointState(self):
+ save_dir = self._get_test_dir("update_checkpoint_state")
+ os.chdir(save_dir)
+ # Make a temporary train directory.
+ train_dir = "train"
+ os.mkdir(train_dir)
+ abs_path = os.path.join(save_dir, "model-0")
+ rel_path = os.path.join("train", "model-2")
+ checkpoint_management.update_checkpoint_state(
+ train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
+ ckpt = checkpoint_management.get_checkpoint_state(train_dir)
+ self.assertEqual(ckpt.model_checkpoint_path, rel_path)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
+
+ def testUpdateCheckpointStateSaveRelativePaths(self):
+ save_dir = self._get_test_dir("update_checkpoint_state")
+ os.chdir(save_dir)
+ abs_path2 = os.path.join(save_dir, "model-2")
+ rel_path2 = "model-2"
+ abs_path0 = os.path.join(save_dir, "model-0")
+ rel_path0 = "model-0"
+ checkpoint_management.update_checkpoint_state_internal(
+ save_dir=save_dir,
+ model_checkpoint_path=abs_path2,
+ all_model_checkpoint_paths=[rel_path0, abs_path2],
+ save_relative_paths=True)
+
+ # File should contain relative paths.
+ file_content = file_io.read_file_to_string(
+ os.path.join(save_dir, "checkpoint"))
+ ckpt = CheckpointState()
+ text_format.Merge(file_content, ckpt)
+ self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
+
+ # get_checkpoint_state should return absolute paths.
+ ckpt = checkpoint_management.get_checkpoint_state(save_dir)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
+
+ def testCheckPointStateFailsWhenIncomplete(self):
+ save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
+ os.chdir(save_dir)
+ ckpt_path = os.path.join(save_dir, "checkpoint")
+ ckpt_file = open(ckpt_path, "w")
+ ckpt_file.write("")
+ ckpt_file.close()
+ with self.assertRaises(ValueError):
+ checkpoint_management.get_checkpoint_state(save_dir)
+
+ def testCheckPointCompletesRelativePaths(self):
+ save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
+ os.chdir(save_dir)
+ ckpt_path = os.path.join(save_dir, "checkpoint")
+ ckpt_file = open(ckpt_path, "w")
+ ckpt_file.write("""
+ model_checkpoint_path: "./model.ckpt-687529"
+ all_model_checkpoint_paths: "./model.ckpt-687500"
+ all_model_checkpoint_paths: "./model.ckpt-687529"
+ """)
+ ckpt_file.close()
+ ckpt = checkpoint_management.get_checkpoint_state(save_dir)
+ self.assertEqual(ckpt.model_checkpoint_path,
+ os.path.join(save_dir, "./model.ckpt-687529"))
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0],
+ os.path.join(save_dir, "./model.ckpt-687500"))
+ self.assertEqual(ckpt.all_model_checkpoint_paths[1],
+ os.path.join(save_dir, "./model.ckpt-687529"))
+
+
+class SaverUtilsTest(test.TestCase):
+
+ def setUp(self):
+ self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
+ gfile.MakeDirs(self._base_dir)
+
+ def tearDown(self):
+ gfile.DeleteRecursively(self._base_dir)
+
+ def testCheckpointExists(self):
+ for sharded in (False, True):
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(sharded=sharded, write_version=version)
+
+ path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
+ self.assertFalse(
+ checkpoint_management.checkpoint_exists(path)) # Not saved yet.
+
+ ckpt_prefix = saver.save(sess, path)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+ ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+ def testGetCheckpointMtimes(self):
+ prefixes = []
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(write_version=version)
+ prefixes.append(
+ saver.save(sess, os.path.join(self._base_dir, str(version))))
+
+ mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
+ self.assertEqual(2, len(mtimes))
+ self.assertTrue(mtimes[1] >= mtimes[0])
+
+ def testRemoveCheckpoint(self):
+ for sharded in (False, True):
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(sharded=sharded, write_version=version)
+
+ path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
+ ckpt_prefix = saver.save(sess, path)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+ checkpoint_management.remove_checkpoint(ckpt_prefix, version)
+ self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index a052081630..9b72b09f08 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -277,7 +278,7 @@ def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
def _get_checkpoint_filename(ckpt_dir_or_file):
"""Returns checkpoint filename given directory or specific checkpoint file."""
if gfile.IsDirectory(ckpt_dir_or_file):
- return saver.latest_checkpoint(ckpt_dir_or_file)
+ return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
return ckpt_dir_or_file
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 35007653a0..8a289b31b5 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -124,14 +124,18 @@ py_test(
],
deps = [
":base",
+ ":tracking",
":util",
+ "//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:saver",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:template",
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 3c1a4a6f83..5506e6bc4e 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
@@ -467,7 +468,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
- root.restore(saver_lib.latest_checkpoint(checkpoint_directory))
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@@ -495,7 +497,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@@ -528,7 +531,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@@ -561,7 +565,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun
@@ -1180,7 +1185,8 @@ class CheckpointingTests(test.TestCase):
optimizer_checkpoint = checkpointable_utils.Checkpoint(
optimizer=optimizer)
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 3806056f01..92533ca4f3 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
- logdir))) as session:
+ checkpoint_filename_with_path=
+ checkpoint_management.latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index c80cdf03be..13a97a9545 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -21,15 +21,12 @@ from __future__ import print_function
import collections
import os.path
-import re
import time
import uuid
import numpy as np
import six
-from google.protobuf import text_format
-
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
@@ -41,7 +38,6 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
-from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
@@ -52,14 +48,19 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saveable_object
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
+# TODO(allenl): Remove these aliases once all users are migrated off.
+get_checkpoint_state = checkpoint_management.get_checkpoint_state
+update_checkpoint_state = checkpoint_management.update_checkpoint_state
+
+
# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
"VariableV2",
@@ -858,218 +859,6 @@ def _get_saver_or_default():
return saver
-def _GetCheckpointFilename(save_dir, latest_filename):
- """Returns a filename for storing the CheckpointState.
-
- Args:
- save_dir: The directory for saving and restoring checkpoints.
- latest_filename: Name of the file in 'save_dir' that is used
- to store the CheckpointState.
-
- Returns:
- The path of the file that contains the CheckpointState proto.
- """
- if latest_filename is None:
- latest_filename = "checkpoint"
- return os.path.join(save_dir, latest_filename)
-
-
-@tf_export("train.generate_checkpoint_state_proto")
-def generate_checkpoint_state_proto(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None):
- """Generates a checkpoint state proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
-
- Returns:
- CheckpointState proto with model_checkpoint_path and
- all_model_checkpoint_paths updated to either absolute paths or
- relative paths to the current save_dir.
- """
- if all_model_checkpoint_paths is None:
- all_model_checkpoint_paths = []
-
- if (not all_model_checkpoint_paths or
- all_model_checkpoint_paths[-1] != model_checkpoint_path):
- logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
- model_checkpoint_path)
- all_model_checkpoint_paths.append(model_checkpoint_path)
-
- # Relative paths need to be rewritten to be relative to the "save_dir"
- # if model_checkpoint_path already contains "save_dir".
- if not os.path.isabs(save_dir):
- if not os.path.isabs(model_checkpoint_path):
- model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
- for i in range(len(all_model_checkpoint_paths)):
- p = all_model_checkpoint_paths[i]
- if not os.path.isabs(p):
- all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
-
- coord_checkpoint_proto = CheckpointState(
- model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
-
- return coord_checkpoint_proto
-
-
-@tf_export("train.update_checkpoint_state")
-def update_checkpoint_state(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None,
- latest_filename=None):
- """Updates the content of the 'checkpoint' file.
-
- This updates the checkpoint file containing a CheckpointState
- proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
-
- Raises:
- RuntimeError: If any of the model checkpoint paths conflict with the file
- containing CheckpointSate.
- """
- _update_checkpoint_state(
- save_dir=save_dir,
- model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths,
- latest_filename=latest_filename,
- save_relative_paths=False)
-
-
-def _update_checkpoint_state(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None,
- latest_filename=None,
- save_relative_paths=False):
- """Updates the content of the 'checkpoint' file.
-
- This updates the checkpoint file containing a CheckpointState
- proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
- save_relative_paths: If `True`, will write relative paths to the checkpoint
- state file.
-
- Raises:
- RuntimeError: If any of the model checkpoint paths conflict with the file
- containing CheckpointSate.
- """
- # Writes the "checkpoint" file for the coordinator for later restoration.
- coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
- if save_relative_paths:
- if os.path.isabs(model_checkpoint_path):
- rel_model_checkpoint_path = os.path.relpath(
- model_checkpoint_path, save_dir)
- else:
- rel_model_checkpoint_path = model_checkpoint_path
- rel_all_model_checkpoint_paths = []
- for p in all_model_checkpoint_paths:
- if os.path.isabs(p):
- rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
- else:
- rel_all_model_checkpoint_paths.append(p)
- ckpt = generate_checkpoint_state_proto(
- save_dir,
- rel_model_checkpoint_path,
- all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
- else:
- ckpt = generate_checkpoint_state_proto(
- save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
-
- if coord_checkpoint_filename == ckpt.model_checkpoint_path:
- raise RuntimeError("Save path '%s' conflicts with path used for "
- "checkpoint state. Please use a different save path." %
- model_checkpoint_path)
-
- # Preventing potential read/write race condition by *atomically* writing to a
- # file.
- file_io.atomic_write_string_to_file(coord_checkpoint_filename,
- text_format.MessageToString(ckpt))
-
-
-@tf_export("train.get_checkpoint_state")
-def get_checkpoint_state(checkpoint_dir, latest_filename=None):
- """Returns CheckpointState proto from the "checkpoint" file.
-
- If the "checkpoint" file contains a valid CheckpointState
- proto, returns it.
-
- Args:
- checkpoint_dir: The directory of checkpoints.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
-
- Returns:
- A CheckpointState if the state was available, None
- otherwise.
-
- Raises:
- ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
- """
- ckpt = None
- coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
- latest_filename)
- f = None
- try:
- # Check that the file exists before opening it to avoid
- # many lines of errors from colossus in the logs.
- if file_io.file_exists(coord_checkpoint_filename):
- file_content = file_io.read_file_to_string(
- coord_checkpoint_filename)
- ckpt = CheckpointState()
- text_format.Merge(file_content, ckpt)
- if not ckpt.model_checkpoint_path:
- raise ValueError("Invalid checkpoint state loaded from "
- + checkpoint_dir)
- # For relative model_checkpoint_path and all_model_checkpoint_paths,
- # prepend checkpoint_dir.
- if not os.path.isabs(ckpt.model_checkpoint_path):
- ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
- ckpt.model_checkpoint_path)
- for i in range(len(ckpt.all_model_checkpoint_paths)):
- p = ckpt.all_model_checkpoint_paths[i]
- if not os.path.isabs(p):
- ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
- except errors.OpError as e:
- # It's ok if the file cannot be read
- logging.warning("%s: %s", type(e).__name__, e)
- logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
- return None
- except text_format.ParseError as e:
- logging.warning("%s: %s", type(e).__name__, e)
- logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
- return None
- finally:
- if f:
- f.close()
- return ckpt
-
-
@tf_export("train.Saver")
class Saver(object):
"""Saves and restores variables.
@@ -1412,7 +1201,7 @@ class Saver(object):
# Otherwise delete the files.
try:
- remove_checkpoint(
+ checkpoint_management.remove_checkpoint(
self._CheckpointFilename(p), self.saver_def.version,
meta_graph_suffix)
except Exception as e: # pylint: disable=broad-except
@@ -1518,7 +1307,7 @@ class Saver(object):
Args:
checkpoint_paths: a list of checkpoint paths.
"""
- mtimes = get_checkpoint_mtimes(checkpoint_paths)
+ mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
def save(self,
@@ -1624,7 +1413,7 @@ class Saver(object):
model_checkpoint_path = compat.as_str(model_checkpoint_path)
if write_state:
self._RecordLastCheckpoint(model_checkpoint_path)
- _update_checkpoint_state(
+ checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
@@ -1639,7 +1428,7 @@ class Saver(object):
raise exc
if write_meta_graph:
- meta_graph_filename = _meta_graph_filename(
+ meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
@@ -1714,7 +1503,7 @@ class Saver(object):
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
- if not checkpoint_exists(compat.as_text(save_path)):
+ if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
@@ -1800,55 +1589,6 @@ class Saver(object):
export_scope=export_scope)
-def _prefix_to_checkpoint_path(prefix, format_version):
- """Returns the pathname of a checkpoint file, given the checkpoint prefix.
-
- For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
- returns the pathname to the index file.
-
- Args:
- prefix: a string, the prefix of a checkpoint.
- format_version: the checkpoint format version that corresponds to the
- prefix.
- Returns:
- The pathname of a checkpoint file, taking into account the checkpoint
- format version.
- """
- if format_version == saver_pb2.SaverDef.V2:
- return prefix + ".index" # The index file identifies a checkpoint.
- return prefix # Just the data file.
-
-
-@tf_export("train.latest_checkpoint")
-def latest_checkpoint(checkpoint_dir, latest_filename=None):
- """Finds the filename of latest saved checkpoint file.
-
- Args:
- checkpoint_dir: Directory where the variables were saved.
- latest_filename: Optional name for the protocol buffer file that
- contains the list of most recent checkpoint filenames.
- See the corresponding argument to `Saver.save()`.
-
- Returns:
- The full path to the latest checkpoint or `None` if no checkpoint was found.
- """
- # Pick the latest checkpoint based on checkpoint state.
- ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
- if ckpt and ckpt.model_checkpoint_path:
- # Look for either a V2 path or a V1 path, with priority for V2.
- v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
- saver_pb2.SaverDef.V2)
- v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
- saver_pb2.SaverDef.V1)
- if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
- v1_path):
- return ckpt.model_checkpoint_path
- else:
- logging.error("Couldn't match files for checkpoint %s",
- ckpt.model_checkpoint_path)
- return None
-
-
@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
@@ -2056,119 +1796,6 @@ def export_meta_graph(filename=None,
return meta_graph_def
-@tf_export("train.checkpoint_exists")
-def checkpoint_exists(checkpoint_prefix):
- """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
-
- This is the recommended way to check if a checkpoint exists, since it takes
- into account the naming difference between V1 and V2 formats.
-
- Args:
- checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
- priority. Typically the result of `Saver.save()` or that of
- `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
- V1/V2.
- Returns:
- A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
- """
- pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
- saver_pb2.SaverDef.V2)
- if file_io.get_matching_files(pathname):
- return True
- elif file_io.get_matching_files(checkpoint_prefix):
- return True
- else:
- return False
-
-
-@tf_export("train.get_checkpoint_mtimes")
-def get_checkpoint_mtimes(checkpoint_prefixes):
- """Returns the mtimes (modification timestamps) of the checkpoints.
-
- Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
- exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
- that priority.
-
- This is the recommended way to get the mtimes, since it takes into account
- the naming difference between V1 and V2 formats.
-
- Args:
- checkpoint_prefixes: a list of checkpoint paths, typically the results of
- `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
- sharded/non-sharded or V1/V2.
- Returns:
- A list of mtimes (in microseconds) of the found checkpoints.
- """
- mtimes = []
-
- def match_maybe_append(pathname):
- fnames = file_io.get_matching_files(pathname)
- if fnames:
- mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
- return True
- return False
-
- for checkpoint_prefix in checkpoint_prefixes:
- # Tries V2's metadata file first.
- pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
- saver_pb2.SaverDef.V2)
- if match_maybe_append(pathname):
- continue
- # Otherwise, tries V1, where the prefix is the complete pathname.
- match_maybe_append(checkpoint_prefix)
-
- return mtimes
-
-
-@tf_export("train.remove_checkpoint")
-def remove_checkpoint(checkpoint_prefix,
- checkpoint_format_version=saver_pb2.SaverDef.V2,
- meta_graph_suffix="meta"):
- """Removes a checkpoint given by `checkpoint_prefix`.
-
- Args:
- checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
- of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
- sharded/non-sharded or V1/V2.
- checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
- `SaverDef.V2`.
- meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
- """
- _delete_file_if_exists(
- _meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
- if checkpoint_format_version == saver_pb2.SaverDef.V2:
- # V2 has a metadata file and some data files.
- _delete_file_if_exists(checkpoint_prefix + ".index")
- _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
- else:
- # V1, Legacy. Exact match on the data file.
- _delete_file_if_exists(checkpoint_prefix)
-
-
-def _delete_file_if_exists(filespec):
- """Deletes files matching `filespec`."""
- for pathname in file_io.get_matching_files(filespec):
- file_io.delete_file(pathname)
-
-
-def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
- """Returns the meta graph filename.
-
- Args:
- checkpoint_filename: Name of the checkpoint file.
- meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
-
- Returns:
- MetaGraph file name.
- """
- # If the checkpoint_filename is sharded, the checkpoint_filename could
- # be of format model.ckpt-step#-?????-of-shard#. For example,
- # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
- basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
- meta_graph_filename = ".".join([basename, meta_graph_suffix])
- return meta_graph_filename
-
-
def _wrap_restore_error_with_msg(err, extra_verbiage):
err_msg = ("Restoring from checkpoint failed. This is most likely "
"due to {} from the checkpoint. Please ensure that you "
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 204e81dda0..941aafc780 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -18,20 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
import functools
import math
import os
import random
-import shutil
-import tempfile
import time
import numpy as np
import six
from google.protobuf.any_pb2 import Any
-from google.protobuf import text_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@@ -71,12 +67,12 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -343,11 +339,13 @@ class SaverTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path1, val)
- self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)
+ self.assertEqual(
+ checkpoint_management.latest_checkpoint(save_dir1), save_path1)
save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
os.renames(save_dir1, save_dir2)
save_path2 = os.path.join(save_dir2, "save_copy_restore")
- self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)
+ self.assertEqual(
+ checkpoint_management.latest_checkpoint(save_dir2), save_path2)
# Start a second session. In that session the parameter nodes
# have not been initialized either.
@@ -857,7 +855,7 @@ class SaveRestoreShardedTest(test.TestCase):
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
- meta_graph_filename = saver_module._meta_graph_filename(val)
+ meta_graph_filename = checkpoint_management.meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
@@ -951,11 +949,11 @@ class SaveRestoreShardedTest(test.TestCase):
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(
- saver_module.latest_checkpoint(self.get_temp_dir()),
+ checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
else:
self.assertEqual(
- saver_module.latest_checkpoint(self.get_temp_dir()),
+ checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
@@ -1105,7 +1103,7 @@ class MaxToKeepTest(test.TestCase):
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
- checkpoint_state = saver_module.get_checkpoint_state(save_dir)
+ checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
@@ -1113,7 +1111,7 @@ class MaxToKeepTest(test.TestCase):
def testMaxToKeepEager(self):
with context.eager_mode():
- save_dir = self._get_test_dir("max_to_keep_non_sharded")
+ save_dir = self._get_test_dir("max_to_keep_eager")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
@@ -1123,7 +1121,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@@ -1131,8 +1129,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@@ -1140,9 +1138,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(None, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@@ -1157,9 +1155,9 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1168,8 +1166,8 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1178,9 +1176,9 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
@@ -1193,7 +1191,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@@ -1201,8 +1199,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@@ -1210,9 +1208,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@@ -1231,15 +1229,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1248,15 +1249,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1268,16 +1272,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1286,15 +1293,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save2.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save2.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1306,16 +1316,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save3.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s2], save3.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
@@ -1326,15 +1339,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should not be deleted because helper is unaware of it)
s1 = save3.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save3.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1365,7 +1381,8 @@ class MaxToKeepTest(test.TestCase):
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
@@ -1373,27 +1390,32 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
- self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertFalse(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
def testNoMaxToKeep(self):
save_dir = self._get_test_dir("no_max_to_keep")
@@ -1408,20 +1430,20 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
# Test max_to_keep being 0.
save2 = saver_module.Saver({"v": v}, max_to_keep=0)
self.assertEqual([], save2.last_checkpoints)
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
self.assertEqual([], save2.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
self.assertEqual([], save2.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
@@ -1432,8 +1454,9 @@ class MaxToKeepTest(test.TestCase):
variables.global_variables_initializer().run()
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertFalse(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class KeepCheckpointEveryNHoursTest(test.TestCase):
@@ -1489,10 +1512,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
self.assertEqual([s3, s4], save.last_checkpoints)
# Check that s1 is still here, but s2 is gone.
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertFalse(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s4))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s4))
class SaveRestoreWithVariableNameMap(test.TestCase):
@@ -1571,221 +1594,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
self._testNonReshape(variables.Variable)
-class LatestCheckpointWithRelativePaths(test.TestCase):
-
- @staticmethod
- @contextlib.contextmanager
- def tempWorkingDir(temppath):
- cwd = os.getcwd()
- os.chdir(temppath)
- try:
- yield
- finally:
- os.chdir(cwd)
-
- @staticmethod
- @contextlib.contextmanager
- def tempDir():
- tempdir = tempfile.mkdtemp()
- try:
- yield tempdir
- finally:
- shutil.rmtree(tempdir)
-
- def testNameCollision(self):
- # Make sure we have a clean directory to work in.
- with self.tempDir() as tempdir:
- # Jump to that directory until this test is done.
- with self.tempWorkingDir(tempdir):
- # Save training snapshots to a relative path.
- traindir = "train/"
- os.mkdir(traindir)
- # Collides with the default name of the checkpoint state file.
- filepath = os.path.join(traindir, "checkpoint")
-
- with self.test_session() as sess:
- unused_a = variables.Variable(0.0) # So that Saver saves something.
- variables.global_variables_initializer().run()
-
- # Should fail.
- saver = saver_module.Saver(sharded=False)
- with self.assertRaisesRegexp(ValueError, "collides with"):
- saver.save(sess, filepath)
-
- # Succeeds: the file will be named "checkpoint-<step>".
- saver.save(sess, filepath, global_step=1)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
- saver = saver_module.Saver(sharded=True)
- saver.save(sess, filepath)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
- saver = saver_module.Saver(sharded=True)
- saver.save(sess, filepath, global_step=1)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- def testRelativePath(self):
- # Make sure we have a clean directory to work in.
- with self.tempDir() as tempdir:
-
- # Jump to that directory until this test is done.
- with self.tempWorkingDir(tempdir):
-
- # Save training snapshots to a relative path.
- traindir = "train/"
- os.mkdir(traindir)
-
- filename = "snapshot"
- filepath = os.path.join(traindir, filename)
-
- with self.test_session() as sess:
- # Build a simple graph.
- v0 = variables.Variable(0.0)
- inc = v0.assign_add(1.0)
-
- save = saver_module.Saver({"v0": v0})
-
- # Record a short training history.
- variables.global_variables_initializer().run()
- save.save(sess, filepath, global_step=0)
- inc.eval()
- save.save(sess, filepath, global_step=1)
- inc.eval()
- save.save(sess, filepath, global_step=2)
-
- with self.test_session() as sess:
- # Build a new graph with different initialization.
- v0 = variables.Variable(-1.0)
-
- # Create a new saver.
- save = saver_module.Saver({"v0": v0})
- variables.global_variables_initializer().run()
-
- # Get the most recent checkpoint name from the training history file.
- name = saver_module.latest_checkpoint(traindir)
- self.assertIsNotNone(name)
-
- # Restore "v0" from that checkpoint.
- save.restore(sess, name)
- self.assertEqual(v0.eval(), 2.0)
-
-
-class CheckpointStateTest(test.TestCase):
-
- def _get_test_dir(self, dirname):
- test_dir = os.path.join(self.get_temp_dir(), dirname)
- gfile.MakeDirs(test_dir)
- return test_dir
-
- def testAbsPath(self):
- save_dir = self._get_test_dir("abs_paths")
- abs_path = os.path.join(save_dir, "model-0")
- ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path)
- self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
-
- def testRelPath(self):
- train_dir = "train"
- model = os.path.join(train_dir, "model-0")
- # model_checkpoint_path should have no "train" directory part.
- new_rel_path = "model-0"
- ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
- self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
-
- def testAllModelCheckpointPaths(self):
- save_dir = self._get_test_dir("all_models_test")
- abs_path = os.path.join(save_dir, "model-0")
- for paths in [None, [], ["model-2"]]:
- ckpt = saver_module.generate_checkpoint_state_proto(
- save_dir, abs_path, all_model_checkpoint_paths=paths)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path)
- self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
- self.assertEqual(
- len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
-
- def testUpdateCheckpointState(self):
- save_dir = self._get_test_dir("update_checkpoint_state")
- os.chdir(save_dir)
- # Make a temporary train directory.
- train_dir = "train"
- os.mkdir(train_dir)
- abs_path = os.path.join(save_dir, "model-0")
- rel_path = os.path.join("train", "model-2")
- saver_module.update_checkpoint_state(
- train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
- ckpt = saver_module.get_checkpoint_state(train_dir)
- self.assertEqual(ckpt.model_checkpoint_path, rel_path)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
-
- def testUpdateCheckpointStateSaveRelativePaths(self):
- save_dir = self._get_test_dir("update_checkpoint_state")
- os.chdir(save_dir)
- abs_path2 = os.path.join(save_dir, "model-2")
- rel_path2 = "model-2"
- abs_path0 = os.path.join(save_dir, "model-0")
- rel_path0 = "model-0"
- saver_module._update_checkpoint_state( # pylint: disable=protected-access
- save_dir=save_dir,
- model_checkpoint_path=abs_path2,
- all_model_checkpoint_paths=[rel_path0, abs_path2],
- save_relative_paths=True)
-
- # File should contain relative paths.
- file_content = file_io.read_file_to_string(
- os.path.join(save_dir, "checkpoint"))
- ckpt = CheckpointState()
- text_format.Merge(file_content, ckpt)
- self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
-
- # get_checkpoint_state should return absolute paths.
- ckpt = saver_module.get_checkpoint_state(save_dir)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
-
- def testCheckPointStateFailsWhenIncomplete(self):
- save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
- os.chdir(save_dir)
- ckpt_path = os.path.join(save_dir, "checkpoint")
- ckpt_file = open(ckpt_path, "w")
- ckpt_file.write("")
- ckpt_file.close()
- with self.assertRaises(ValueError):
- saver_module.get_checkpoint_state(save_dir)
-
- def testCheckPointCompletesRelativePaths(self):
- save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
- os.chdir(save_dir)
- ckpt_path = os.path.join(save_dir, "checkpoint")
- ckpt_file = open(ckpt_path, "w")
- ckpt_file.write("""
- model_checkpoint_path: "./model.ckpt-687529"
- all_model_checkpoint_paths: "./model.ckpt-687500"
- all_model_checkpoint_paths: "./model.ckpt-687529"
- """)
- ckpt_file.close()
- ckpt = saver_module.get_checkpoint_state(save_dir)
- self.assertEqual(ckpt.model_checkpoint_path,
- os.path.join(save_dir, "./model.ckpt-687529"))
- self.assertEqual(ckpt.all_model_checkpoint_paths[0],
- os.path.join(save_dir, "./model.ckpt-687500"))
- self.assertEqual(ckpt.all_model_checkpoint_paths[1],
- os.path.join(save_dir, "./model.ckpt-687529"))
-
-
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
@@ -2628,62 +2436,6 @@ class WriteGraphTest(test.TestCase):
self.assertTrue(os.path.exists(path))
-class SaverUtilsTest(test.TestCase):
-
- def setUp(self):
- self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
- gfile.MakeDirs(self._base_dir)
-
- def tearDown(self):
- gfile.DeleteRecursively(self._base_dir)
-
- def testCheckpointExists(self):
- for sharded in (False, True):
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(sharded=sharded, write_version=version)
-
- path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
- self.assertFalse(
- saver_module.checkpoint_exists(path)) # Not saved yet.
-
- ckpt_prefix = saver.save(sess, path)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
-
- ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
-
- def testGetCheckpointMtimes(self):
- prefixes = []
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(write_version=version)
- prefixes.append(
- saver.save(sess, os.path.join(self._base_dir, str(version))))
-
- mtimes = saver_module.get_checkpoint_mtimes(prefixes)
- self.assertEqual(2, len(mtimes))
- self.assertTrue(mtimes[1] >= mtimes[0])
-
- def testRemoveCheckpoint(self):
- for sharded in (False, True):
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(sharded=sharded, write_version=version)
-
- path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
- ckpt_prefix = saver.save(sess, path)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
- saver_module.remove_checkpoint(ckpt_prefix, version)
- self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
-
-
class ScopedGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index 974f75777f..a2e0645ba8 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -24,7 +24,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver as saver_mod
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util.tf_export import tf_export
@@ -197,13 +197,13 @@ class SessionManager(object):
# Waits up until max_wait_secs for checkpoint to become available.
wait_time = 0
- ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
+ ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
while not ckpt or not ckpt.model_checkpoint_path:
if wait_for_checkpoint and wait_time < max_wait_secs:
logging.info("Waiting for checkpoint to be available.")
time.sleep(self._recovery_wait_secs)
wait_time += self._recovery_wait_secs
- ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
+ ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
else:
return sess, False
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 6670d9365f..d7e6dac95b 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager
@@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
+ checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
# Cannot set both checkpoint_dir and checkpoint_filename_with_path.
with self.assertRaises(ValueError):
self._test_recovered_variable(
checkpoint_dir=checkpoint_dir,
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
+ checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
def testWaitForSessionReturnsNoneAfterTimeout(self):
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 4abce85852..71ed88093a 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
@@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase):
end_time = time.time() + timeout_secs
while time.time() < end_time:
if for_checkpoint:
- if saver_lib.checkpoint_exists(pattern):
+ if checkpoint_management.checkpoint_exists(pattern):
return
else:
if len(gfile.Glob(pattern)) >= 1:
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 3f2dc67976..544010afbe 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -82,12 +82,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator
from tensorflow.python.training.monitored_session import MonitoredSession
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorflow.python.training.saver import Saver
-from tensorflow.python.training.saver import checkpoint_exists
-from tensorflow.python.training.saver import generate_checkpoint_state_proto
-from tensorflow.python.training.saver import get_checkpoint_mtimes
-from tensorflow.python.training.saver import get_checkpoint_state
-from tensorflow.python.training.saver import latest_checkpoint
-from tensorflow.python.training.saver import update_checkpoint_state
+from tensorflow.python.training.checkpoint_management import checkpoint_exists
+from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto
+from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes
+from tensorflow.python.training.checkpoint_management import get_checkpoint_state
+from tensorflow.python.training.checkpoint_management import latest_checkpoint
+from tensorflow.python.training.checkpoint_management import update_checkpoint_state
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.saver import import_meta_graph
from tensorflow.python.training.session_run_hook import SessionRunHook