diff options
author | 2018-08-02 15:47:43 -0700 | |
---|---|---|
committer | 2018-08-02 15:51:17 -0700 | |
commit | 1bf206bc82f600886f1e19c9860f09f18984346b (patch) | |
tree | fbd6ee10df16e491142017e96120181b81a72ec5 /tensorflow/contrib/learn | |
parent | 6fbbad97e293cc39bde32495e92614c69a9a7896 (diff) |
Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and globbing it on to saver.py seems dirty.
PiperOrigin-RevId: 207179646
Diffstat (limited to 'tensorflow/contrib/learn')
7 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a026a15e4..c1de42782e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary as core_summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import monitored_session from tensorflow.python.training import saver @@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, as_iterable=True, iterate_batches=False): # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index f8a3709ee5..08e907a608 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): # Load and cache the path of the most recent checkpoint to avoid duplicate # searches on GCS. logging.info("Checking for checkpoint in %s", self._model_dir) - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.warning("Skipping evaluation and export since model has not been " @@ -516,7 +516,8 @@ class Experiment(object): start = time.time() error_msg = None - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if not latest_path: error_msg = ("Estimator is not fitted yet. " "Will start an evaluation when a checkpoint is ready.") @@ -778,7 +779,8 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + latest_checkpoint = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) eval_result = self._call_evaluate( input_fn=self._eval_input_fn, steps=self._eval_steps, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 0d039d593b..df156da3f4 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path @@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 77f7c73d54..3d691d4340 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as core_summary -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -735,7 +735,8 @@ class ValidationMonitor(EveryN): return False self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.debug("Skipping evaluation since model has not been saved yet " "at step %d.", step) @@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN): def end(self, session=None): super(ExportMonitor, self).end(session=session) - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.info("Skipping export at the end since model has not been saved " "yet.") diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 5c34d0ddb0..ff1da32c21 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -39,9 +39,9 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver from tensorflow.python.training import training_util @@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase): mock_latest_checkpoint.assert_called_with(model_dir) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_early_stopping_rounds(self, mock_latest_checkpoint, mock_estimator_class): @@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase): self._assert_validation_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase): monitor.epoch_end(epoch=0) monitor.end() - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) model_dir = 'model/dir' @@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase): expected_best_metrics={'loss': 42.0, 'auc': 0.5}) monitor.post_step(step=step, session=None) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_fail_with_core_estimator_and_metrics( self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 3eacac7a3d..0144b93814 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util @@ -298,7 +299,8 @@ def _export_estimator(estimator, # If checkpoint_path is specified, use the specified checkpoint path. checkpoint_path = (checkpoint_path or - tf_saver.latest_checkpoint(estimator._model_dir)) + checkpoint_management.latest_checkpoint( + estimator._model_dir)) with ops.Graph().as_default() as g: training_util.create_global_step(g) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index f8106d1e4a..66af6833da 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated @@ -714,7 +714,8 @@ def make_best_model_export_strategy( # as soon as contrib is cleaned up and we can thus be sure that # estimator is a tf.estimator.Estimator and not a # tf.contrib.learn.Estimator - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) |