diff options
35 files changed, 1011 insertions, 817 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30a993b1f7..77148aceec 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase): def _read_vars(self, model_dir): """Returns (global_step, latest_feature).""" with ops.Graph().as_default() as g: - ckpt_path = saver_lib.latest_checkpoint(model_dir) + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) meta_filename = ckpt_path + '.meta' saver_lib.import_meta_graph(meta_filename) saver = saver_lib.Saver() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 393f08850b..3ed4dfb729 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase): return 
os.path.join(self.get_temp_dir(), "iterator") def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) def _save(self, sess, saver): saver.save(sess, self._ckpt_path()) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 0d71be6601..d2c1d0d362 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # Check if there is an existing checkpoint. If so, restore from it. 
# pylint: disable=protected-access - latest_checkpoint_path = saver_lib.latest_checkpoint( + latest_checkpoint_path = checkpoint_management.latest_checkpoint( self._checkpoint_saver_hook._checkpoint_dir, latest_filename=self._latest_filename) if latest_checkpoint_path: diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 2917eaac97..a753d77580 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -314,7 +314,8 @@ class IteratorTest(test.TestCase): for i in range(5): iterator = datasets.Iterator(dataset) checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) - checkpoint.restore(saver.latest_checkpoint(checkpoint_directory)) + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for j in range(2): self.assertEqual(i * 2 + j, iterator.get_next().numpy()) checkpoint.save(file_prefix=checkpoint_prefix) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 8ac553e0ae..d18a097063 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import saver +from tensorflow.python.training import 
checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) object_graph = checkpointable_utils.object_metadata( - saver.latest_checkpoint(config.logdir)) + checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 9e356dd965..e7184a01fb 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as train __all__ = [ @@ -40,7 +40,7 @@ __all__ = [ def _get_checkpoint_filename(filepattern): """Returns checkpoint filename given directory or specific filepattern.""" if gfile.IsDirectory(filepattern): - return saver.latest_checkpoint(filepattern) + return checkpoint_management.latest_checkpoint(filepattern) return filepattern diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a026a15e4..c1de42782e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -72,6 +72,7 @@ from tensorflow.python.saved_model 
import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary as core_summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import monitored_session from tensorflow.python.training import saver @@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, as_iterable=True, iterate_batches=False): # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." 
% self._model_dir) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index f8a3709ee5..08e907a608 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): # Load and cache the path of the most recent checkpoint to avoid duplicate # searches on GCS. logging.info("Checking for checkpoint in %s", self._model_dir) - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.warning("Skipping evaluation and export since model has not been " @@ -516,7 +516,8 @@ class Experiment(object): start = time.time() error_msg = None - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if not latest_path: error_msg = ("Estimator is not fitted yet. 
" "Will start an evaluation when a checkpoint is ready.") @@ -778,7 +779,8 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + latest_checkpoint = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) eval_result = self._call_evaluate( input_fn=self._eval_input_fn, steps=self._eval_steps, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 0d039d593b..df156da3f4 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path @@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. 
def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 77f7c73d54..3d691d4340 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as core_summary -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -735,7 +735,8 @@ class ValidationMonitor(EveryN): return False self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. 
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.debug("Skipping evaluation since model has not been saved yet " "at step %d.", step) @@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN): def end(self, session=None): super(ExportMonitor, self).end(session=session) - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.info("Skipping export at the end since model has not been saved " "yet.") diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 5c34d0ddb0..ff1da32c21 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -39,9 +39,9 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver from tensorflow.python.training import training_util @@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase): mock_latest_checkpoint.assert_called_with(model_dir) @test.mock.patch.object(estimators, 'Estimator', 
autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_early_stopping_rounds(self, mock_latest_checkpoint, mock_estimator_class): @@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase): self._assert_validation_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase): monitor.epoch_end(epoch=0) monitor.end() - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) model_dir = 'model/dir' @@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase): expected_best_metrics={'loss': 42.0, 'auc': 0.5}) monitor.post_step(step=step, session=None) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_fail_with_core_estimator_and_metrics( self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py 
b/tensorflow/contrib/learn/python/learn/utils/export.py index 3eacac7a3d..0144b93814 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util @@ -298,7 +299,8 @@ def _export_estimator(estimator, # If checkpoint_path is specified, use the specified checkpoint path. checkpoint_path = (checkpoint_path or - tf_saver.latest_checkpoint(estimator._model_dir)) + checkpoint_management.latest_checkpoint( + estimator._model_dir)) with ops.Graph().as_default() as g: training_util.create_global_step(g) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index f8106d1e4a..66af6833da 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated @@ -714,7 +714,8 @@ def make_best_model_export_strategy( # as soon as contrib is cleaned up and we can thus be sure that # estimator is a tf.estimator.Estimator and not a # tf.contrib.learn.Estimator - checkpoint_path = 
saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 06ab58188a..28a531dfec 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import tracking @@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. 
input_value = constant_op.constant([[3.]]) @@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index af3b2ad1b5..c2166594e5 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -22,8 +22,8 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.contrib.predictor import predictor from tensorflow.python.framework import ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training 
import monitored_session -from tensorflow.python.training import saver class ContribEstimatorPredictor(predictor.Predictor): @@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor): # pylint: disable=protected-access model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) # pylint: enable=protected-access - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( config=config, diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index f7fd66d33f..01bac891da 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import evaluation from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir, logging.info('Waiting for new checkpoint at %s', checkpoint_dir) stop_time = time.time() + timeout if timeout is not None else None while True: - checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir) if checkpoint_path is None or checkpoint_path == last_checkpoint: if stop_time is not None and time.time() + seconds_to_sleep > stop_time: return None diff --git 
a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py index 4877c010fa..94cf7788b2 100644 --- a/tensorflow/contrib/training/python/training/training_test.py +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -421,7 +422,7 @@ class TrainTest(test.TestCase): train_op = self.create_train_op() model_variables = variables_lib2.global_variables() - model_path = saver_lib.latest_checkpoint(logdir1) + model_path = checkpoint_management.latest_checkpoint(logdir1) assign_fn = variables_lib.assign_from_checkpoint_fn( model_path, model_variables) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2b8110a999..7cf8ddb1d9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3216,6 +3216,7 @@ py_library( "training/checkpointable/**/*.py", # The following targets have their own build rules (same name as the # file): + "training/checkpoint_management.py", "training/saveable_object.py", "training/saver.py", "training/training_util.py", @@ -3223,8 +3224,10 @@ py_library( ), srcs_version = "PY2AND3", deps = [ + "saver", ":array_ops", ":array_ops_gen", + ":checkpoint_management", ":checkpoint_ops_gen", ":client", ":control_flow_ops", @@ -3236,25 +3239,20 @@ py_library( ":framework_ops", ":gradients", ":init_ops", - ":distribute", ":io_ops", - ":io_ops_gen", ":layers_base", - ":lib", ":lookup_ops", ":math_ops", ":platform", - ":protos_all_py", ":pywrap_tensorflow", ":random_ops", ":resource_variable_ops", ":resources", - 
"saver", - ":saveable_object", ":sdca_ops", + ":session", ":sparse_ops", + ":sparse_tensor", ":state_ops", - ":string_ops", ":summary", ":training_ops_gen", ":training_util", @@ -3264,6 +3262,7 @@ py_library( "//third_party/py/numpy", "@six_archive//:six", "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", # `layers` dependency only exists due to the use of a small utility. @@ -3281,11 +3280,25 @@ py_library( ) py_library( + name = "checkpoint_management", + srcs = ["training/checkpoint_management.py"], + deps = [ + ":errors", + ":lib", + ":platform", + ":protos_all_py", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + +py_library( name = "saver", srcs = ["training/saver.py"], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":checkpoint_management", ":constant_op", ":control_flow_ops", ":device", @@ -3294,9 +3307,7 @@ py_library( ":framework_ops", ":io_ops", ":io_ops_gen", - ":lib", ":platform", - ":protos_all_py", ":pywrap_tensorflow", ":resource_variable_ops", ":saveable_object", @@ -4423,6 +4434,42 @@ cuda_py_test( tags = ["multi_gpu"], ) +cuda_py_test( + name = "checkpoint_management_test", + size = "small", + srcs = [ + "training/checkpoint_management_test.py", + ], + additional_deps = [ + ":array_ops", + ":client_testlib", + ":control_flow_ops", + ":data_flow_ops", + ":errors", + ":gradients", + ":math_ops", + ":nn_grad", + ":nn_ops", + ":saver_test_utils", + ":partitioned_variables", + ":platform", + ":platform_test", + ":pywrap_tensorflow", + ":random_ops", + ":resource_variable_ops", + ":sparse_ops", + ":summary", + ":training", + ":util", + ":variable_scope", + ":variables", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "saver_large_variable_test", size = "medium", @@ -4489,6 +4536,7 @@ tf_py_test( srcs = 
["training/supervisor_test.py"], additional_deps = [ ":array_ops", + ":checkpoint_management", ":client_testlib", ":errors", ":framework", @@ -4496,6 +4544,7 @@ tf_py_test( ":io_ops", ":parsing_ops", ":platform", + ":saver", ":summary", ":training", ":variables", @@ -4609,10 +4658,13 @@ py_test( tags = ["notsan"], # b/67945581 deps = [ ":array_ops", + ":checkpoint_management", ":client_testlib", ":control_flow_ops", ":errors", ":framework_for_generated_wrappers", + ":resource_variable_ops", + ":saver", ":session", ":state_ops", ":summary", diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index dd39262f9b..352424514e 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -47,7 +47,7 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -877,7 +877,7 @@ class IteratorCheckpointingTest(test.TestCase): checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) for i in range(5): with self.test_session() as sess: - checkpoint.restore(saver.latest_checkpoint( + checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)).initialize_or_restore(sess) for j in range(2): self.assertEqual(i * 2 + j, sess.run(get_next)) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index a4e735a092..43deb8bc6c 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -53,6 +53,7 @@ from 
tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import evaluation @@ -268,7 +269,7 @@ class Estimator(object): found. """ with context.graph_mode(): - return saver.latest_checkpoint(self.model_dir) + return checkpoint_management.latest_checkpoint(self.model_dir) def train(self, input_fn, @@ -417,7 +418,7 @@ class Estimator(object): # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.info('Could not find trained model in model_dir: {}, running ' 'initialization to evaluate.'.format(self._model_dir)) @@ -504,7 +505,8 @@ class Estimator(object): hooks = _check_hooks_type(hooks) # Check that model has been trained. if not checkpoint_path: - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + self._model_dir) if not checkpoint_path: logging.info('Could not find trained model in model_dir: {}, running ' 'initialization to predict.'.format(self._model_dir)) @@ -769,7 +771,8 @@ class Estimator(object): with context.graph_mode(): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + self._model_dir) if not checkpoint_path: raise ValueError("Couldn't find trained model at %s." 
% self._model_dir) @@ -1626,7 +1629,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution): def _check_checkpoint_available(model_dir): - latest_path = saver.latest_checkpoint(model_dir) + latest_path = checkpoint_management.latest_checkpoint(model_dir) if not latest_path: raise ValueError( 'Could not find trained model in model_dir: {}.'.format(model_dir)) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 68fc5bcadf..e8552092e0 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -69,6 +69,7 @@ from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils @@ -1548,7 +1549,8 @@ class EstimatorPredictTest(test.TestCase): next( est.predict( dummy_input_fn, - checkpoint_path=saver.latest_checkpoint('fakedir'))) + checkpoint_path= + checkpoint_management.latest_checkpoint('fakedir'))) def test_tensor_predictions(self): diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 079560c495..c63deb8f4d 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import metrics as metrics_module from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver as saver_lib from 
tensorflow.python.training import training_util @@ -442,7 +443,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config): # save checkpoint into subdirectory to allow warm start keras_model_dir = os.path.join(config.model_dir, 'keras') # Load weights and save to checkpoint if there is no checkpoint - latest_path = saver_lib.latest_checkpoint(keras_model_dir) + latest_path = checkpoint_management.latest_checkpoint(keras_model_dir) if not latest_path: keras_weights = None if _any_weight_initialized(keras_model): diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 4349699a94..130fe70beb 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -78,7 +79,7 @@ def freeze_graph_with_def_protos(input_graph_def, # 'input_checkpoint' may be a prefix if we're using Saver V2 format if (not input_saved_model_dir and - not saver_lib.checkpoint_exists(input_checkpoint)): + not checkpoint_management.checkpoint_exists(input_checkpoint)): print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") return -1 diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py new file mode 100644 index 0000000000..aaddc015ed --- /dev/null +++ b/tensorflow/python/training/checkpoint_management.py @@ -0,0 +1,406 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=invalid-name +"""Save and restore variables.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import re + +from google.protobuf import text_format + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from tensorflow.python.util.tf_export import tf_export + + +def _GetCheckpointFilename(save_dir, latest_filename): + """Returns a filename for storing the CheckpointState. + + Args: + save_dir: The directory for saving and restoring checkpoints. + latest_filename: Name of the file in 'save_dir' that is used + to store the CheckpointState. + + Returns: + The path of the file that contains the CheckpointState proto. + """ + if latest_filename is None: + latest_filename = "checkpoint" + return os.path.join(save_dir, latest_filename) + + +@tf_export("train.generate_checkpoint_state_proto") +def generate_checkpoint_state_proto(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None): + """Generates a checkpoint state proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. 
Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + + Returns: + CheckpointState proto with model_checkpoint_path and + all_model_checkpoint_paths updated to either absolute paths or + relative paths to the current save_dir. + """ + if all_model_checkpoint_paths is None: + all_model_checkpoint_paths = [] + + if (not all_model_checkpoint_paths or + all_model_checkpoint_paths[-1] != model_checkpoint_path): + logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", + model_checkpoint_path) + all_model_checkpoint_paths.append(model_checkpoint_path) + + # Relative paths need to be rewritten to be relative to the "save_dir" + # if model_checkpoint_path already contains "save_dir". + if not os.path.isabs(save_dir): + if not os.path.isabs(model_checkpoint_path): + model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) + for i in range(len(all_model_checkpoint_paths)): + p = all_model_checkpoint_paths[i] + if not os.path.isabs(p): + all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) + + coord_checkpoint_proto = CheckpointState( + model_checkpoint_path=model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths) + + return coord_checkpoint_proto + + +@tf_export("train.update_checkpoint_state") +def update_checkpoint_state(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None, + latest_filename=None): + """Updates the content of the 'checkpoint' file. + + This updates the checkpoint file containing a CheckpointState + proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. 
If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + + Raises: + RuntimeError: If any of the model checkpoint paths conflict with the file + containing CheckpointState. + """ + update_checkpoint_state_internal( + save_dir=save_dir, + model_checkpoint_path=model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths, + latest_filename=latest_filename, + save_relative_paths=False) + + +def update_checkpoint_state_internal(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None, + latest_filename=None, + save_relative_paths=False): + """Updates the content of the 'checkpoint' file. + + This updates the checkpoint file containing a CheckpointState + proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + save_relative_paths: If `True`, will write relative paths to the checkpoint + state file. + + Raises: + RuntimeError: If any of the model checkpoint paths conflict with the file + containing CheckpointState. + """ + # Writes the "checkpoint" file for the coordinator for later restoration.
+ coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) + if save_relative_paths: + if os.path.isabs(model_checkpoint_path): + rel_model_checkpoint_path = os.path.relpath( + model_checkpoint_path, save_dir) + else: + rel_model_checkpoint_path = model_checkpoint_path + rel_all_model_checkpoint_paths = [] + for p in all_model_checkpoint_paths: + if os.path.isabs(p): + rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) + else: + rel_all_model_checkpoint_paths.append(p) + ckpt = generate_checkpoint_state_proto( + save_dir, + rel_model_checkpoint_path, + all_model_checkpoint_paths=rel_all_model_checkpoint_paths) + else: + ckpt = generate_checkpoint_state_proto( + save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths) + + if coord_checkpoint_filename == ckpt.model_checkpoint_path: + raise RuntimeError("Save path '%s' conflicts with path used for " + "checkpoint state. Please use a different save path." % + model_checkpoint_path) + + # Preventing potential read/write race condition by *atomically* writing to a + # file. + file_io.atomic_write_string_to_file(coord_checkpoint_filename, + text_format.MessageToString(ckpt)) + + +@tf_export("train.get_checkpoint_state") +def get_checkpoint_state(checkpoint_dir, latest_filename=None): + """Returns CheckpointState proto from the "checkpoint" file. + + If the "checkpoint" file contains a valid CheckpointState + proto, returns it. + + Args: + checkpoint_dir: The directory of checkpoints. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + + Returns: + A CheckpointState if the state was available, None + otherwise. + + Raises: + ValueError: if the checkpoint read doesn't have model_checkpoint_path set. 
+ """ + ckpt = None + coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, + latest_filename) + f = None + try: + # Check that the file exists before opening it to avoid + # many lines of errors from colossus in the logs. + if file_io.file_exists(coord_checkpoint_filename): + file_content = file_io.read_file_to_string( + coord_checkpoint_filename) + ckpt = CheckpointState() + text_format.Merge(file_content, ckpt) + if not ckpt.model_checkpoint_path: + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) + # For relative model_checkpoint_path and all_model_checkpoint_paths, + # prepend checkpoint_dir. + if not os.path.isabs(ckpt.model_checkpoint_path): + ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, + ckpt.model_checkpoint_path) + for i in range(len(ckpt.all_model_checkpoint_paths)): + p = ckpt.all_model_checkpoint_paths[i] + if not os.path.isabs(p): + ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) + except errors.OpError as e: + # It's ok if the file cannot be read + logging.warning("%s: %s", type(e).__name__, e) + logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) + return None + except text_format.ParseError as e: + logging.warning("%s: %s", type(e).__name__, e) + logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) + return None + finally: + if f: + f.close() + return ckpt + + +def _prefix_to_checkpoint_path(prefix, format_version): + """Returns the pathname of a checkpoint file, given the checkpoint prefix. + + For V1 checkpoint, simply returns the prefix itself (the data file). For V2, + returns the pathname to the index file. + + Args: + prefix: a string, the prefix of a checkpoint. + format_version: the checkpoint format version that corresponds to the + prefix. + Returns: + The pathname of a checkpoint file, taking into account the checkpoint + format version. 
+ """ + if format_version == saver_pb2.SaverDef.V2: + return prefix + ".index" # The index file identifies a checkpoint. + return prefix # Just the data file. + + +@tf_export("train.latest_checkpoint") +def latest_checkpoint(checkpoint_dir, latest_filename=None): + """Finds the filename of latest saved checkpoint file. + + Args: + checkpoint_dir: Directory where the variables were saved. + latest_filename: Optional name for the protocol buffer file that + contains the list of most recent checkpoint filenames. + See the corresponding argument to `Saver.save()`. + + Returns: + The full path to the latest checkpoint or `None` if no checkpoint was found. + """ + # Pick the latest checkpoint based on checkpoint state. + ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) + if ckpt and ckpt.model_checkpoint_path: + # Look for either a V2 path or a V1 path, with priority for V2. + v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, + saver_pb2.SaverDef.V2) + v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, + saver_pb2.SaverDef.V1) + if file_io.get_matching_files(v2_path) or file_io.get_matching_files( + v1_path): + return ckpt.model_checkpoint_path + else: + logging.error("Couldn't match files for checkpoint %s", + ckpt.model_checkpoint_path) + return None + + +@tf_export("train.checkpoint_exists") +def checkpoint_exists(checkpoint_prefix): + """Checks whether a V1 or V2 checkpoint exists with the specified prefix. + + This is the recommended way to check if a checkpoint exists, since it takes + into account the naming difference between V1 and V2 formats. + + Args: + checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking + priority. Typically the result of `Saver.save()` or that of + `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or + V1/V2. + Returns: + A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. 
+ """ + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if file_io.get_matching_files(pathname): + return True + elif file_io.get_matching_files(checkpoint_prefix): + return True + else: + return False + + +@tf_export("train.get_checkpoint_mtimes") +def get_checkpoint_mtimes(checkpoint_prefixes): + """Returns the mtimes (modification timestamps) of the checkpoints. + + Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files + exist, collect their mtime. Both V2 and V1 checkpoints are considered, in + that priority. + + This is the recommended way to get the mtimes, since it takes into account + the naming difference between V1 and V2 formats. + + Args: + checkpoint_prefixes: a list of checkpoint paths, typically the results of + `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2. + Returns: + A list of mtimes (in seconds) of the found checkpoints. + """ + mtimes = [] + + def match_maybe_append(pathname): + fnames = file_io.get_matching_files(pathname) + if fnames: + mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) + return True + return False + + for checkpoint_prefix in checkpoint_prefixes: + # Tries V2's metadata file first. + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if match_maybe_append(pathname): + continue + # Otherwise, tries V1, where the prefix is the complete pathname. + match_maybe_append(checkpoint_prefix) + + return mtimes + + +@tf_export("train.remove_checkpoint") +def remove_checkpoint(checkpoint_prefix, + checkpoint_format_version=saver_pb2.SaverDef.V2, + meta_graph_suffix="meta"): + """Removes a checkpoint given by `checkpoint_prefix`. + + Args: + checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result + of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2.
+ checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to + `SaverDef.V2`. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + """ + _delete_file_if_exists( + meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) + if checkpoint_format_version == saver_pb2.SaverDef.V2: + # V2 has a metadata file and some data files. + _delete_file_if_exists(checkpoint_prefix + ".index") + _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") + else: + # V1, Legacy. Exact match on the data file. + _delete_file_if_exists(checkpoint_prefix) + + +def _delete_file_if_exists(filespec): + """Deletes files matching `filespec`.""" + for pathname in file_io.get_matching_files(filespec): + file_io.delete_file(pathname) + + +def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): + """Returns the meta graph filename. + + Args: + checkpoint_filename: Name of the checkpoint file. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + + Returns: + MetaGraph file name. + """ + # If the checkpoint_filename is sharded, the checkpoint_filename could + # be of format model.ckpt-step#-?????-of-shard#. For example, + # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. + basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) + suffixed_filename = ".".join([basename, meta_graph_suffix]) + return suffixed_filename diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py new file mode 100644 index 0000000000..4b31d0c613 --- /dev/null +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -0,0 +1,316 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for tensorflow.python.training.checkpoint_management.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import shutil +import tempfile + +from google.protobuf import text_format + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_module +from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState + + +class LatestCheckpointWithRelativePaths(test.TestCase): + + @staticmethod + @contextlib.contextmanager + def tempWorkingDir(temppath): + cwd = os.getcwd() + os.chdir(temppath) + try: + yield + finally: + os.chdir(cwd) + + @staticmethod + @contextlib.contextmanager + def tempDir(): + tempdir = tempfile.mkdtemp() + try: + yield tempdir + finally: + shutil.rmtree(tempdir) + + def testNameCollision(self): + # Make sure we have a clean directory to work in. + with self.tempDir() as tempdir: + # Jump to that directory until this test is done. + with self.tempWorkingDir(tempdir): + # Save training snapshots to a relative path.
+ traindir = "train/" + os.mkdir(traindir) + # Collides with the default name of the checkpoint state file. + filepath = os.path.join(traindir, "checkpoint") + + with self.test_session() as sess: + unused_a = variables.Variable(0.0) # So that Saver saves something. + variables.global_variables_initializer().run() + + # Should fail. + saver = saver_module.Saver(sharded=False) + with self.assertRaisesRegexp(ValueError, "collides with"): + saver.save(sess, filepath) + + # Succeeds: the file will be named "checkpoint-<step>". + saver.save(sess, filepath, global_step=1) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + # Succeeds: the file will be named "checkpoint-<i>-of-<n>". + saver = saver_module.Saver(sharded=True) + saver.save(sess, filepath) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>". + saver = saver_module.Saver(sharded=True) + saver.save(sess, filepath, global_step=1) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + def testRelativePath(self): + # Make sure we have a clean directory to work in. + with self.tempDir() as tempdir: + + # Jump to that directory until this test is done. + with self.tempWorkingDir(tempdir): + + # Save training snapshots to a relative path. + traindir = "train/" + os.mkdir(traindir) + + filename = "snapshot" + filepath = os.path.join(traindir, filename) + + with self.test_session() as sess: + # Build a simple graph. + v0 = variables.Variable(0.0) + inc = v0.assign_add(1.0) + + save = saver_module.Saver({"v0": v0}) + + # Record a short training history. + variables.global_variables_initializer().run() + save.save(sess, filepath, global_step=0) + inc.eval() + save.save(sess, filepath, global_step=1) + inc.eval() + save.save(sess, filepath, global_step=2) + + with self.test_session() as sess: + # Build a new graph with different initialization. 
+ v0 = variables.Variable(-1.0) + + # Create a new saver. + save = saver_module.Saver({"v0": v0}) + variables.global_variables_initializer().run() + + # Get the most recent checkpoint name from the training history file. + name = checkpoint_management.latest_checkpoint(traindir) + self.assertIsNotNone(name) + + # Restore "v0" from that checkpoint. + save.restore(sess, name) + self.assertEqual(v0.eval(), 2.0) + + +class CheckpointStateTest(test.TestCase): + + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + + def testAbsPath(self): + save_dir = self._get_test_dir("abs_paths") + abs_path = os.path.join(save_dir, "model-0") + ckpt = checkpoint_management.generate_checkpoint_state_proto( + save_dir, abs_path) + self.assertEqual(ckpt.model_checkpoint_path, abs_path) + self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) + + def testRelPath(self): + train_dir = "train" + model = os.path.join(train_dir, "model-0") + # model_checkpoint_path should have no "train" directory part. 
+ new_rel_path = "model-0" + ckpt = checkpoint_management.generate_checkpoint_state_proto( + train_dir, model) + self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) + + def testAllModelCheckpointPaths(self): + save_dir = self._get_test_dir("all_models_test") + abs_path = os.path.join(save_dir, "model-0") + for paths in [None, [], ["model-2"]]: + ckpt = checkpoint_management.generate_checkpoint_state_proto( + save_dir, abs_path, all_model_checkpoint_paths=paths) + self.assertEqual(ckpt.model_checkpoint_path, abs_path) + self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) + self.assertEqual( + len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) + + def testUpdateCheckpointState(self): + save_dir = self._get_test_dir("update_checkpoint_state") + os.chdir(save_dir) + # Make a temporary train directory. 
+ train_dir = "train" + os.mkdir(train_dir) + abs_path = os.path.join(save_dir, "model-0") + rel_path = os.path.join("train", "model-2") + checkpoint_management.update_checkpoint_state( + train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) + ckpt = checkpoint_management.get_checkpoint_state(train_dir) + self.assertEqual(ckpt.model_checkpoint_path, rel_path) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) + + def testUpdateCheckpointStateSaveRelativePaths(self): + save_dir = self._get_test_dir("update_checkpoint_state") + os.chdir(save_dir) + abs_path2 = os.path.join(save_dir, "model-2") + rel_path2 = "model-2" + abs_path0 = os.path.join(save_dir, "model-0") + rel_path0 = "model-0" + checkpoint_management.update_checkpoint_state_internal( + save_dir=save_dir, + model_checkpoint_path=abs_path2, + all_model_checkpoint_paths=[rel_path0, abs_path2], + save_relative_paths=True) + + # File should contain relative paths. + file_content = file_io.read_file_to_string( + os.path.join(save_dir, "checkpoint")) + ckpt = CheckpointState() + text_format.Merge(file_content, ckpt) + self.assertEqual(ckpt.model_checkpoint_path, rel_path2) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) + + # get_checkpoint_state should return absolute paths. 
+ ckpt = checkpoint_management.get_checkpoint_state(save_dir) + self.assertEqual(ckpt.model_checkpoint_path, abs_path2) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) + + def testCheckPointStateFailsWhenIncomplete(self): + save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") + os.chdir(save_dir) + ckpt_path = os.path.join(save_dir, "checkpoint") + ckpt_file = open(ckpt_path, "w") + ckpt_file.write("") + ckpt_file.close() + with self.assertRaises(ValueError): + checkpoint_management.get_checkpoint_state(save_dir) + + def testCheckPointCompletesRelativePaths(self): + save_dir = self._get_test_dir("checkpoint_completes_relative_paths") + os.chdir(save_dir) + ckpt_path = os.path.join(save_dir, "checkpoint") + ckpt_file = open(ckpt_path, "w") + ckpt_file.write(""" + model_checkpoint_path: "./model.ckpt-687529" + all_model_checkpoint_paths: "./model.ckpt-687500" + all_model_checkpoint_paths: "./model.ckpt-687529" + """) + ckpt_file.close() + ckpt = checkpoint_management.get_checkpoint_state(save_dir) + self.assertEqual(ckpt.model_checkpoint_path, + os.path.join(save_dir, "./model.ckpt-687529")) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], + os.path.join(save_dir, "./model.ckpt-687500")) + self.assertEqual(ckpt.all_model_checkpoint_paths[1], + os.path.join(save_dir, "./model.ckpt-687529")) + + +class SaverUtilsTest(test.TestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") + gfile.MakeDirs(self._base_dir) + + def tearDown(self): + gfile.DeleteRecursively(self._base_dir) + + def testCheckpointExists(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() 
+ saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + self.assertFalse( + checkpoint_management.checkpoint_exists(path)) # Not saved yet. + + ckpt_prefix = saver.save(sess, path) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + def testGetCheckpointMtimes(self): + prefixes = [] + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(write_version=version) + prefixes.append( + saver.save(sess, os.path.join(self._base_dir, str(version)))) + + mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes) + self.assertEqual(2, len(mtimes)) + self.assertTrue(mtimes[1] >= mtimes[0]) + + def testRemoveCheckpoint(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + ckpt_prefix = saver.save(sess, path) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + checkpoint_management.remove_checkpoint(ckpt_prefix, version) + self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index a052081630..9b72b09f08 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ 
b/tensorflow/python/training/checkpoint_utils.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -277,7 +278,7 @@ def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map): def _get_checkpoint_filename(ckpt_dir_or_file): """Returns checkpoint filename given directory or specific checkpoint file.""" if gfile.IsDirectory(ckpt_dir_or_file): - return saver.latest_checkpoint(ckpt_dir_or_file) + return checkpoint_management.latest_checkpoint(ckpt_dir_or_file) return ckpt_dir_or_file diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 35007653a0..8a289b31b5 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -124,14 +124,18 @@ py_test( ], deps = [ ":base", + ":tracking", ":util", + "//tensorflow/python:checkpoint_management", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:saver", "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:template", diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 3c1a4a6f83..5506e6bc4e 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -42,6 
+42,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base @@ -467,7 +468,8 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(saver_lib.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) @@ -495,7 +497,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -528,7 +531,8 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -561,7 +565,8 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, 
global_step=training_util.get_or_create_global_step()) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun @@ -1180,7 +1185,8 @@ class CheckpointingTests(test.TestCase): optimizer_checkpoint = checkpointable_utils.Checkpoint( optimizer=optimizer) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 3806056f01..92533ca4f3 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import coordinator from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase): with monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( scaffold, - checkpoint_filename_with_path=saver_lib.latest_checkpoint( - logdir))) as session: + checkpoint_filename_with_path= + checkpoint_management.latest_checkpoint(logdir))) as session: self.assertEqual(2, session.run(gstep)) def test_retry_initialization_on_aborted_error(self): diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index c80cdf03be..13a97a9545 100644 --- 
a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -21,15 +21,12 @@ from __future__ import print_function import collections import os.path -import re import time import uuid import numpy as np import six -from google.protobuf import text_format - from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 @@ -41,7 +38,6 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops -from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops @@ -52,14 +48,19 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saveable_object from tensorflow.python.training import training_util -from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export +# TODO(allenl): Remove these aliases once all users are migrated off. +get_checkpoint_state = checkpoint_management.get_checkpoint_state +update_checkpoint_state = checkpoint_management.update_checkpoint_state + + # Op names which identify variable reads which should be saved. _VARIABLE_OPS = set(["Variable", "VariableV2", @@ -858,218 +859,6 @@ def _get_saver_or_default(): return saver -def _GetCheckpointFilename(save_dir, latest_filename): - """Returns a filename for storing the CheckpointState. 
- - Args: - save_dir: The directory for saving and restoring checkpoints. - latest_filename: Name of the file in 'save_dir' that is used - to store the CheckpointState. - - Returns: - The path of the file that contains the CheckpointState proto. - """ - if latest_filename is None: - latest_filename = "checkpoint" - return os.path.join(save_dir, latest_filename) - - -@tf_export("train.generate_checkpoint_state_proto") -def generate_checkpoint_state_proto(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None): - """Generates a checkpoint state proto. - - Args: - save_dir: Directory where the model was saved. - model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - - Returns: - CheckpointState proto with model_checkpoint_path and - all_model_checkpoint_paths updated to either absolute paths or - relative paths to the current save_dir. - """ - if all_model_checkpoint_paths is None: - all_model_checkpoint_paths = [] - - if (not all_model_checkpoint_paths or - all_model_checkpoint_paths[-1] != model_checkpoint_path): - logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", - model_checkpoint_path) - all_model_checkpoint_paths.append(model_checkpoint_path) - - # Relative paths need to be rewritten to be relative to the "save_dir" - # if model_checkpoint_path already contains "save_dir". 
- if not os.path.isabs(save_dir): - if not os.path.isabs(model_checkpoint_path): - model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) - for i in range(len(all_model_checkpoint_paths)): - p = all_model_checkpoint_paths[i] - if not os.path.isabs(p): - all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) - - coord_checkpoint_proto = CheckpointState( - model_checkpoint_path=model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths) - - return coord_checkpoint_proto - - -@tf_export("train.update_checkpoint_state") -def update_checkpoint_state(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None, - latest_filename=None): - """Updates the content of the 'checkpoint' file. - - This updates the checkpoint file containing a CheckpointState - proto. - - Args: - save_dir: Directory where the model was saved. - model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - - Raises: - RuntimeError: If any of the model checkpoint paths conflict with the file - containing CheckpointSate. - """ - _update_checkpoint_state( - save_dir=save_dir, - model_checkpoint_path=model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths, - latest_filename=latest_filename, - save_relative_paths=False) - - -def _update_checkpoint_state(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None, - latest_filename=None, - save_relative_paths=False): - """Updates the content of the 'checkpoint' file. - - This updates the checkpoint file containing a CheckpointState - proto. - - Args: - save_dir: Directory where the model was saved. 
- model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - save_relative_paths: If `True`, will write relative paths to the checkpoint - state file. - - Raises: - RuntimeError: If any of the model checkpoint paths conflict with the file - containing CheckpointSate. - """ - # Writes the "checkpoint" file for the coordinator for later restoration. - coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) - if save_relative_paths: - if os.path.isabs(model_checkpoint_path): - rel_model_checkpoint_path = os.path.relpath( - model_checkpoint_path, save_dir) - else: - rel_model_checkpoint_path = model_checkpoint_path - rel_all_model_checkpoint_paths = [] - for p in all_model_checkpoint_paths: - if os.path.isabs(p): - rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) - else: - rel_all_model_checkpoint_paths.append(p) - ckpt = generate_checkpoint_state_proto( - save_dir, - rel_model_checkpoint_path, - all_model_checkpoint_paths=rel_all_model_checkpoint_paths) - else: - ckpt = generate_checkpoint_state_proto( - save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths) - - if coord_checkpoint_filename == ckpt.model_checkpoint_path: - raise RuntimeError("Save path '%s' conflicts with path used for " - "checkpoint state. Please use a different save path." % - model_checkpoint_path) - - # Preventing potential read/write race condition by *atomically* writing to a - # file. 
- file_io.atomic_write_string_to_file(coord_checkpoint_filename, - text_format.MessageToString(ckpt)) - - -@tf_export("train.get_checkpoint_state") -def get_checkpoint_state(checkpoint_dir, latest_filename=None): - """Returns CheckpointState proto from the "checkpoint" file. - - If the "checkpoint" file contains a valid CheckpointState - proto, returns it. - - Args: - checkpoint_dir: The directory of checkpoints. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - - Returns: - A CheckpointState if the state was available, None - otherwise. - - Raises: - ValueError: if the checkpoint read doesn't have model_checkpoint_path set. - """ - ckpt = None - coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, - latest_filename) - f = None - try: - # Check that the file exists before opening it to avoid - # many lines of errors from colossus in the logs. - if file_io.file_exists(coord_checkpoint_filename): - file_content = file_io.read_file_to_string( - coord_checkpoint_filename) - ckpt = CheckpointState() - text_format.Merge(file_content, ckpt) - if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from " - + checkpoint_dir) - # For relative model_checkpoint_path and all_model_checkpoint_paths, - # prepend checkpoint_dir. 
- if not os.path.isabs(ckpt.model_checkpoint_path): - ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, - ckpt.model_checkpoint_path) - for i in range(len(ckpt.all_model_checkpoint_paths)): - p = ckpt.all_model_checkpoint_paths[i] - if not os.path.isabs(p): - ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) - except errors.OpError as e: - # It's ok if the file cannot be read - logging.warning("%s: %s", type(e).__name__, e) - logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) - return None - except text_format.ParseError as e: - logging.warning("%s: %s", type(e).__name__, e) - logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) - return None - finally: - if f: - f.close() - return ckpt - - @tf_export("train.Saver") class Saver(object): """Saves and restores variables. @@ -1412,7 +1201,7 @@ class Saver(object): # Otherwise delete the files. try: - remove_checkpoint( + checkpoint_management.remove_checkpoint( self._CheckpointFilename(p), self.saver_def.version, meta_graph_suffix) except Exception as e: # pylint: disable=broad-except @@ -1518,7 +1307,7 @@ class Saver(object): Args: checkpoint_paths: a list of checkpoint paths. 
""" - mtimes = get_checkpoint_mtimes(checkpoint_paths) + mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths) self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) def save(self, @@ -1624,7 +1413,7 @@ class Saver(object): model_checkpoint_path = compat.as_str(model_checkpoint_path) if write_state: self._RecordLastCheckpoint(model_checkpoint_path) - _update_checkpoint_state( + checkpoint_management.update_checkpoint_state_internal( save_dir=save_path_parent, model_checkpoint_path=model_checkpoint_path, all_model_checkpoint_paths=self.last_checkpoints, @@ -1639,7 +1428,7 @@ class Saver(object): raise exc if write_meta_graph: - meta_graph_filename = _meta_graph_filename( + meta_graph_filename = checkpoint_management.meta_graph_filename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not context.executing_eagerly(): with sess.graph.as_default(): @@ -1714,7 +1503,7 @@ class Saver(object): if save_path is None: raise ValueError("Can't load save_path when it is None.") - if not checkpoint_exists(compat.as_text(save_path)): + if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)): raise ValueError("The passed save_path is not a valid checkpoint: " + compat.as_text(save_path)) @@ -1800,55 +1589,6 @@ class Saver(object): export_scope=export_scope) -def _prefix_to_checkpoint_path(prefix, format_version): - """Returns the pathname of a checkpoint file, given the checkpoint prefix. - - For V1 checkpoint, simply returns the prefix itself (the data file). For V2, - returns the pathname to the index file. - - Args: - prefix: a string, the prefix of a checkpoint. - format_version: the checkpoint format version that corresponds to the - prefix. - Returns: - The pathname of a checkpoint file, taking into account the checkpoint - format version. - """ - if format_version == saver_pb2.SaverDef.V2: - return prefix + ".index" # The index file identifies a checkpoint. - return prefix # Just the data file. 
- - -@tf_export("train.latest_checkpoint") -def latest_checkpoint(checkpoint_dir, latest_filename=None): - """Finds the filename of latest saved checkpoint file. - - Args: - checkpoint_dir: Directory where the variables were saved. - latest_filename: Optional name for the protocol buffer file that - contains the list of most recent checkpoint filenames. - See the corresponding argument to `Saver.save()`. - - Returns: - The full path to the latest checkpoint or `None` if no checkpoint was found. - """ - # Pick the latest checkpoint based on checkpoint state. - ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) - if ckpt and ckpt.model_checkpoint_path: - # Look for either a V2 path or a V1 path, with priority for V2. - v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, - saver_pb2.SaverDef.V2) - v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, - saver_pb2.SaverDef.V1) - if file_io.get_matching_files(v2_path) or file_io.get_matching_files( - v1_path): - return ckpt.model_checkpoint_path - else: - logging.error("Couldn't match files for checkpoint %s", - ckpt.model_checkpoint_path) - return None - - @tf_export("train.import_meta_graph") def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs): @@ -2056,119 +1796,6 @@ def export_meta_graph(filename=None, return meta_graph_def -@tf_export("train.checkpoint_exists") -def checkpoint_exists(checkpoint_prefix): - """Checks whether a V1 or V2 checkpoint exists with the specified prefix. - - This is the recommended way to check if a checkpoint exists, since it takes - into account the naming difference between V1 and V2 formats. - - Args: - checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking - priority. Typically the result of `Saver.save()` or that of - `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or - V1/V2. - Returns: - A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. 
- """ - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if file_io.get_matching_files(pathname): - return True - elif file_io.get_matching_files(checkpoint_prefix): - return True - else: - return False - - -@tf_export("train.get_checkpoint_mtimes") -def get_checkpoint_mtimes(checkpoint_prefixes): - """Returns the mtimes (modification timestamps) of the checkpoints. - - Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files - exist, collect their mtime. Both V2 and V1 checkpoints are considered, in - that priority. - - This is the recommended way to get the mtimes, since it takes into account - the naming difference between V1 and V2 formats. - - Args: - checkpoint_prefixes: a list of checkpoint paths, typically the results of - `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of - sharded/non-sharded or V1/V2. - Returns: - A list of mtimes (in microseconds) of the found checkpoints. - """ - mtimes = [] - - def match_maybe_append(pathname): - fnames = file_io.get_matching_files(pathname) - if fnames: - mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) - return True - return False - - for checkpoint_prefix in checkpoint_prefixes: - # Tries V2's metadata file first. - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if match_maybe_append(pathname): - continue - # Otherwise, tries V1, where the prefix is the complete pathname. - match_maybe_append(checkpoint_prefix) - - return mtimes - - -@tf_export("train.remove_checkpoint") -def remove_checkpoint(checkpoint_prefix, - checkpoint_format_version=saver_pb2.SaverDef.V2, - meta_graph_suffix="meta"): - """Removes a checkpoint given by `checkpoint_prefix`. - - Args: - checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result - of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of - sharded/non-sharded or V1/V2. 
- checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to - `SaverDef.V2`. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - """ - _delete_file_if_exists( - _meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) - if checkpoint_format_version == saver_pb2.SaverDef.V2: - # V2 has a metadata file and some data files. - _delete_file_if_exists(checkpoint_prefix + ".index") - _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") - else: - # V1, Legacy. Exact match on the data file. - _delete_file_if_exists(checkpoint_prefix) - - -def _delete_file_if_exists(filespec): - """Deletes files matching `filespec`.""" - for pathname in file_io.get_matching_files(filespec): - file_io.delete_file(pathname) - - -def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): - """Returns the meta graph filename. - - Args: - checkpoint_filename: Name of the checkpoint file. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - - Returns: - MetaGraph file name. - """ - # If the checkpoint_filename is sharded, the checkpoint_filename could - # be of format model.ckpt-step#-?????-of-shard#. For example, - # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. - basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) - meta_graph_filename = ".".join([basename, meta_graph_suffix]) - return meta_graph_filename - - def _wrap_restore_error_with_msg(err, extra_verbiage): err_msg = ("Restoring from checkpoint failed. This is most likely " "due to {} from the checkpoint. 
Please ensure that you " diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 204e81dda0..941aafc780 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -18,20 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import functools import math import os import random -import shutil -import tempfile import time import numpy as np import six from google.protobuf.any_pb2 import Any -from google.protobuf import text_format from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 @@ -71,12 +67,12 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import adam +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils from tensorflow.python.training import training_util -from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.training.checkpointable import base as checkpointable_base from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -343,11 +339,13 @@ class SaverTest(test.TestCase): self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path1, val) - self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1) + self.assertEqual( + checkpoint_management.latest_checkpoint(save_dir1), save_path1) save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") os.renames(save_dir1, save_dir2) save_path2 = 
os.path.join(save_dir2, "save_copy_restore") - self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2) + self.assertEqual( + checkpoint_management.latest_checkpoint(save_dir2), save_path2) # Start a second session. In that session the parameter nodes # have not been initialized either. @@ -857,7 +855,7 @@ class SaveRestoreShardedTest(test.TestCase): self.assertEqual(save_path + "-?????-of-00002", val) else: self.assertEqual(save_path, val) - meta_graph_filename = saver_module._meta_graph_filename(val) + meta_graph_filename = checkpoint_management.meta_graph_filename(val) self.assertEqual(save_path + ".meta", meta_graph_filename) if save._write_version is saver_pb2.SaverDef.V1: @@ -951,11 +949,11 @@ class SaveRestoreShardedTest(test.TestCase): if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual( - saver_module.latest_checkpoint(self.get_temp_dir()), + checkpoint_management.latest_checkpoint(self.get_temp_dir()), os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002")) else: self.assertEqual( - saver_module.latest_checkpoint(self.get_temp_dir()), + checkpoint_management.latest_checkpoint(self.get_temp_dir()), os.path.join(self.get_temp_dir(), "sharded_basics")) def testSaverDef(self): @@ -1105,7 +1103,7 @@ class MaxToKeepTest(test.TestCase): def assertCheckpointState(self, model_checkpoint_path, all_model_checkpoint_paths, save_dir): - checkpoint_state = saver_module.get_checkpoint_state(save_dir) + checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) self.assertEqual(checkpoint_state.model_checkpoint_path, model_checkpoint_path) self.assertEqual(checkpoint_state.all_model_checkpoint_paths, @@ -1113,7 +1111,7 @@ class MaxToKeepTest(test.TestCase): def testMaxToKeepEager(self): with context.eager_mode(): - save_dir = self._get_test_dir("max_to_keep_non_sharded") + save_dir = self._get_test_dir("max_to_keep_eager") v = variable_scope.variable(10.0, name="v") save = saver_module.Saver({"v": v}, max_to_keep=2) 
@@ -1123,7 +1121,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(None, os.path.join(save_dir, "s1")) self.assertEqual([s1], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s1], @@ -1131,8 +1129,8 @@ class MaxToKeepTest(test.TestCase): s2 = save.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s1, s2], @@ -1140,9 +1138,9 @@ class MaxToKeepTest(test.TestCase): s3 = save.save(None, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertCheckpointState( model_checkpoint_path=s3, all_model_checkpoint_paths=[s2, s3], @@ -1157,9 +1155,9 @@ class MaxToKeepTest(test.TestCase): # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) + 
self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1168,8 +1166,8 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(None, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1178,9 +1176,9 @@ class MaxToKeepTest(test.TestCase): s2 = save2.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) # Deleted by the first helper. 
- self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) def testNonSharded(self): save_dir = self._get_test_dir("max_to_keep_non_sharded") @@ -1193,7 +1191,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s1], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s1], @@ -1201,8 +1199,8 @@ class MaxToKeepTest(test.TestCase): s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s1, s2], @@ -1210,9 +1208,9 @@ class MaxToKeepTest(test.TestCase): s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertCheckpointState( model_checkpoint_path=s3, all_model_checkpoint_paths=[s2, s3], @@ -1231,15 +1229,18 @@ class MaxToKeepTest(test.TestCase): # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 
self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) - self.assertTrue(saver_module.checkpoint_exists(s3)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1248,15 +1249,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) 
self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1268,16 +1272,19 @@ class MaxToKeepTest(test.TestCase): s2 = save2.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1286,15 +1293,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save2.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save2.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - 
saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1306,16 +1316,19 @@ class MaxToKeepTest(test.TestCase): s2 = save3.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s2], save3.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) # Even though the file for s1 exists, this saver isn't aware of it, which # is why it doesn't end up in the checkpoint state. 
self.assertCheckpointState( @@ -1326,15 +1339,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should not be deleted because helper is unaware of it) s1 = save3.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save3.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1365,7 +1381,8 @@ class MaxToKeepTest(test.TestCase): else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) @@ -1373,27 +1390,32 @@ class MaxToKeepTest(test.TestCase): self.assertEqual(2, len(gfile.Glob(s1))) else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, 
len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s2))) s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) self.assertEqual(0, len(gfile.Glob(s1 + "*"))) - self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertFalse( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s2))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s3))) else: self.assertEqual(4, len(gfile.Glob(s3 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s3))) def testNoMaxToKeep(self): save_dir = self._get_test_dir("no_max_to_keep") @@ -1408,20 +1430,20 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([], save.last_checkpoints) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) # Test max_to_keep being 0. 
save2 = saver_module.Saver({"v": v}, max_to_keep=0) self.assertEqual([], save2.last_checkpoints) s1 = save2.save(sess, os.path.join(save_dir2, "s1")) self.assertEqual([], save2.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) s2 = save2.save(sess, os.path.join(save_dir2, "s2")) self.assertEqual([], save2.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) def testNoMetaGraph(self): save_dir = self._get_test_dir("no_meta_graph") @@ -1432,8 +1454,9 @@ class MaxToKeepTest(test.TestCase): variables.global_variables_initializer().run() s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertFalse( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) class KeepCheckpointEveryNHoursTest(test.TestCase): @@ -1489,10 +1512,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): self.assertEqual([s3, s4], save.last_checkpoints) # Check that s1 is still here, but s2 is gone. 
- self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s4)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertFalse(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) + self.assertTrue(checkpoint_management.checkpoint_exists(s4)) class SaveRestoreWithVariableNameMap(test.TestCase): @@ -1571,221 +1594,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase): self._testNonReshape(variables.Variable) -class LatestCheckpointWithRelativePaths(test.TestCase): - - @staticmethod - @contextlib.contextmanager - def tempWorkingDir(temppath): - cwd = os.getcwd() - os.chdir(temppath) - try: - yield - finally: - os.chdir(cwd) - - @staticmethod - @contextlib.contextmanager - def tempDir(): - tempdir = tempfile.mkdtemp() - try: - yield tempdir - finally: - shutil.rmtree(tempdir) - - def testNameCollision(self): - # Make sure we have a clean directory to work in. - with self.tempDir() as tempdir: - # Jump to that directory until this test is done. - with self.tempWorkingDir(tempdir): - # Save training snapshots to a relative path. - traindir = "train/" - os.mkdir(traindir) - # Collides with the default name of the checkpoint state file. - filepath = os.path.join(traindir, "checkpoint") - - with self.test_session() as sess: - unused_a = variables.Variable(0.0) # So that Saver saves something. - variables.global_variables_initializer().run() - - # Should fail. - saver = saver_module.Saver(sharded=False) - with self.assertRaisesRegexp(ValueError, "collides with"): - saver.save(sess, filepath) - - # Succeeds: the file will be named "checkpoint-<step>". - saver.save(sess, filepath, global_step=1) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - # Succeeds: the file will be named "checkpoint-<i>-of-<n>". 
- saver = saver_module.Saver(sharded=True) - saver.save(sess, filepath) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>". - saver = saver_module.Saver(sharded=True) - saver.save(sess, filepath, global_step=1) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - def testRelativePath(self): - # Make sure we have a clean directory to work in. - with self.tempDir() as tempdir: - - # Jump to that directory until this test is done. - with self.tempWorkingDir(tempdir): - - # Save training snapshots to a relative path. - traindir = "train/" - os.mkdir(traindir) - - filename = "snapshot" - filepath = os.path.join(traindir, filename) - - with self.test_session() as sess: - # Build a simple graph. - v0 = variables.Variable(0.0) - inc = v0.assign_add(1.0) - - save = saver_module.Saver({"v0": v0}) - - # Record a short training history. - variables.global_variables_initializer().run() - save.save(sess, filepath, global_step=0) - inc.eval() - save.save(sess, filepath, global_step=1) - inc.eval() - save.save(sess, filepath, global_step=2) - - with self.test_session() as sess: - # Build a new graph with different initialization. - v0 = variables.Variable(-1.0) - - # Create a new saver. - save = saver_module.Saver({"v0": v0}) - variables.global_variables_initializer().run() - - # Get the most recent checkpoint name from the training history file. - name = saver_module.latest_checkpoint(traindir) - self.assertIsNotNone(name) - - # Restore "v0" from that checkpoint. 
- save.restore(sess, name) - self.assertEqual(v0.eval(), 2.0) - - -class CheckpointStateTest(test.TestCase): - - def _get_test_dir(self, dirname): - test_dir = os.path.join(self.get_temp_dir(), dirname) - gfile.MakeDirs(test_dir) - return test_dir - - def testAbsPath(self): - save_dir = self._get_test_dir("abs_paths") - abs_path = os.path.join(save_dir, "model-0") - ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path) - self.assertEqual(ckpt.model_checkpoint_path, abs_path) - self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) - - def testRelPath(self): - train_dir = "train" - model = os.path.join(train_dir, "model-0") - # model_checkpoint_path should have no "train" directory part. - new_rel_path = "model-0" - ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model) - self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) - - def testAllModelCheckpointPaths(self): - save_dir = self._get_test_dir("all_models_test") - abs_path = os.path.join(save_dir, "model-0") - for paths in [None, [], ["model-2"]]: - ckpt = saver_module.generate_checkpoint_state_proto( - save_dir, abs_path, all_model_checkpoint_paths=paths) - self.assertEqual(ckpt.model_checkpoint_path, abs_path) - self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) - self.assertEqual( - len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) - - def testUpdateCheckpointState(self): - save_dir = self._get_test_dir("update_checkpoint_state") - os.chdir(save_dir) - # Make a temporary train directory. 
- train_dir = "train" - os.mkdir(train_dir) - abs_path = os.path.join(save_dir, "model-0") - rel_path = os.path.join("train", "model-2") - saver_module.update_checkpoint_state( - train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) - ckpt = saver_module.get_checkpoint_state(train_dir) - self.assertEqual(ckpt.model_checkpoint_path, rel_path) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) - - def testUpdateCheckpointStateSaveRelativePaths(self): - save_dir = self._get_test_dir("update_checkpoint_state") - os.chdir(save_dir) - abs_path2 = os.path.join(save_dir, "model-2") - rel_path2 = "model-2" - abs_path0 = os.path.join(save_dir, "model-0") - rel_path0 = "model-0" - saver_module._update_checkpoint_state( # pylint: disable=protected-access - save_dir=save_dir, - model_checkpoint_path=abs_path2, - all_model_checkpoint_paths=[rel_path0, abs_path2], - save_relative_paths=True) - - # File should contain relative paths. - file_content = file_io.read_file_to_string( - os.path.join(save_dir, "checkpoint")) - ckpt = CheckpointState() - text_format.Merge(file_content, ckpt) - self.assertEqual(ckpt.model_checkpoint_path, rel_path2) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) - - # get_checkpoint_state should return absolute paths. 
- ckpt = saver_module.get_checkpoint_state(save_dir) - self.assertEqual(ckpt.model_checkpoint_path, abs_path2) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) - - def testCheckPointStateFailsWhenIncomplete(self): - save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") - os.chdir(save_dir) - ckpt_path = os.path.join(save_dir, "checkpoint") - ckpt_file = open(ckpt_path, "w") - ckpt_file.write("") - ckpt_file.close() - with self.assertRaises(ValueError): - saver_module.get_checkpoint_state(save_dir) - - def testCheckPointCompletesRelativePaths(self): - save_dir = self._get_test_dir("checkpoint_completes_relative_paths") - os.chdir(save_dir) - ckpt_path = os.path.join(save_dir, "checkpoint") - ckpt_file = open(ckpt_path, "w") - ckpt_file.write(""" - model_checkpoint_path: "./model.ckpt-687529" - all_model_checkpoint_paths: "./model.ckpt-687500" - all_model_checkpoint_paths: "./model.ckpt-687529" - """) - ckpt_file.close() - ckpt = saver_module.get_checkpoint_state(save_dir) - self.assertEqual(ckpt.model_checkpoint_path, - os.path.join(save_dir, "./model.ckpt-687529")) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], - os.path.join(save_dir, "./model.ckpt-687500")) - self.assertEqual(ckpt.all_model_checkpoint_paths[1], - os.path.join(save_dir, "./model.ckpt-687529")) - - class MetaGraphTest(test.TestCase): def _get_test_dir(self, dirname): @@ -2628,62 +2436,6 @@ class WriteGraphTest(test.TestCase): self.assertTrue(os.path.exists(path)) -class SaverUtilsTest(test.TestCase): - - def setUp(self): - self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") - gfile.MakeDirs(self._base_dir) - - def tearDown(self): - gfile.DeleteRecursively(self._base_dir) - - def testCheckpointExists(self): - for sharded in (False, True): - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with 
self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(sharded=sharded, write_version=version) - - path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) - self.assertFalse( - saver_module.checkpoint_exists(path)) # Not saved yet. - - ckpt_prefix = saver.save(sess, path) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - - ckpt_prefix = saver_module.latest_checkpoint(self._base_dir) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - - def testGetCheckpointMtimes(self): - prefixes = [] - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(write_version=version) - prefixes.append( - saver.save(sess, os.path.join(self._base_dir, str(version)))) - - mtimes = saver_module.get_checkpoint_mtimes(prefixes) - self.assertEqual(2, len(mtimes)) - self.assertTrue(mtimes[1] >= mtimes[0]) - - def testRemoveCheckpoint(self): - for sharded in (False, True): - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(sharded=sharded, write_version=version) - - path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) - ckpt_prefix = saver.save(sess, path) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - saver_module.remove_checkpoint(ckpt_prefix, version) - self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix)) - - class ScopedGraphTest(test.TestCase): def _get_test_dir(self, dirname): diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index 974f75777f..a2e0645ba8 
100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -24,7 +24,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver as saver_mod +from tensorflow.python.training import checkpoint_management from tensorflow.python.util.tf_export import tf_export @@ -197,13 +197,13 @@ class SessionManager(object): # Waits up until max_wait_secs for checkpoint to become available. wait_time = 0 - ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) + ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) while not ckpt or not ckpt.model_checkpoint_path: if wait_for_checkpoint and wait_time < max_wait_secs: logging.info("Waiting for checkpoint to be available.") time.sleep(self._recovery_wait_secs) wait_time += self._recovery_wait_secs - ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) + ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) else: return sess, False diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index 6670d9365f..d7e6dac95b 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import server_lib from tensorflow.python.training import session_manager @@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase): os.path.join(checkpoint_dir, "recover_session_checkpoint")) 
self._test_recovered_variable(checkpoint_dir=checkpoint_dir) self._test_recovered_variable( - checkpoint_filename_with_path=saver_lib.latest_checkpoint( + checkpoint_filename_with_path=checkpoint_management.latest_checkpoint( checkpoint_dir)) # Cannot set both checkpoint_dir and checkpoint_filename_with_path. with self.assertRaises(ValueError): self._test_recovered_variable( checkpoint_dir=checkpoint_dir, - checkpoint_filename_with_path=saver_lib.latest_checkpoint( + checkpoint_filename_with_path=checkpoint_management.latest_checkpoint( checkpoint_dir)) def testWaitForSessionReturnsNoneAfterTimeout(self): diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 4abce85852..71ed88093a 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import input as input_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import server_lib @@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase): end_time = time.time() + timeout_secs while time.time() < end_time: if for_checkpoint: - if saver_lib.checkpoint_exists(pattern): + if checkpoint_management.checkpoint_exists(pattern): return else: if len(gfile.Glob(pattern)) >= 1: diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 3f2dc67976..544010afbe 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -82,12 +82,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator from tensorflow.python.training.monitored_session import MonitoredSession from 
tensorflow.python.training.monitored_session import SingularMonitoredSession from tensorflow.python.training.saver import Saver -from tensorflow.python.training.saver import checkpoint_exists -from tensorflow.python.training.saver import generate_checkpoint_state_proto -from tensorflow.python.training.saver import get_checkpoint_mtimes -from tensorflow.python.training.saver import get_checkpoint_state -from tensorflow.python.training.saver import latest_checkpoint -from tensorflow.python.training.saver import update_checkpoint_state +from tensorflow.python.training.checkpoint_management import checkpoint_exists +from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto +from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes +from tensorflow.python.training.checkpoint_management import get_checkpoint_state +from tensorflow.python.training.checkpoint_management import latest_checkpoint +from tensorflow.python.training.checkpoint_management import update_checkpoint_state from tensorflow.python.training.saver import export_meta_graph from tensorflow.python.training.saver import import_meta_graph from tensorflow.python.training.session_run_hook import SessionRunHook |