aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-02 15:47:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 15:51:17 -0700
commit1bf206bc82f600886f1e19c9860f09f18984346b (patch)
treefbd6ee10df16e491142017e96120181b81a72ec5 /tensorflow/contrib/learn
parent6fbbad97e293cc39bde32495e92614c69a9a7896 (diff)
Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and globbing it on to saver.py seems dirty. PiperOrigin-RevId: 207179646
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py14
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py5
7 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 7a026a15e4..c1de42782e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index f8a3709ee5..08e907a608 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
# Load and cache the path of the most recent checkpoint to avoid duplicate
# searches on GCS.
logging.info("Checking for checkpoint in %s", self._model_dir)
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.warning("Skipping evaluation and export since model has not been "
@@ -516,7 +516,8 @@ class Experiment(object):
start = time.time()
error_msg = None
- latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
@@ -778,7 +779,8 @@ class Experiment(object):
saving_listeners=self._saving_listeners)
logging.info("Evaluating model now.")
- latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_checkpoint = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
eval_result = self._call_evaluate(
input_fn=self._eval_input_fn,
steps=self._eval_steps,
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 0d039d593b..df156da3f4 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 77f7c73d54..3d691d4340 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
-from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
@@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
return False
self._last_checkpoint_check_time = current_time
# Check that we are not running evaluation on the same checkpoint.
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.debug("Skipping evaluation since model has not been saved yet "
"at step %d.", step)
@@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.info("Skipping export at the end since model has not been saved "
"yet.")
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 5c34d0ddb0..ff1da32c21 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
from tensorflow.python.training import training_util
@@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase):
mock_latest_checkpoint.assert_called_with(model_dir)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_early_stopping_rounds(self,
mock_latest_checkpoint,
mock_estimator_class):
@@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
monitor.epoch_end(epoch=0)
monitor.end()
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
model_dir = 'model/dir'
@@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
expected_best_metrics={'loss': 42.0, 'auc': 0.5})
monitor.post_step(step=step, session=None)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_fail_with_core_estimator_and_metrics(
self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 3eacac7a3d..0144b93814 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
@@ -298,7 +299,8 @@ def _export_estimator(estimator,
# If checkpoint_path is specified, use the specified checkpoint path.
checkpoint_path = (checkpoint_path or
- tf_saver.latest_checkpoint(estimator._model_dir))
+ checkpoint_management.latest_checkpoint(
+ estimator._model_dir))
with ops.Graph().as_default() as g:
training_util.create_global_step(g)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index f8106d1e4a..66af6833da 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@@ -714,7 +714,8 @@ def make_best_model_export_strategy(
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)