author    Martin Wicke <wicke@google.com>  2016-12-01 23:10:01 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-12-01 23:23:10 -0800
commit    40d4b8efa317411113784135612adc960f2721aa (patch)
tree      f962d65ccf26850b524cdb9168592fffbfc765f4
parent    c78b1ec74d1c4aab88bca6e9010c7582c1f89706 (diff)
Don't evaluate checkpoints twice in continuous eval.
Change: 140819261
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py      31
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/random_forest.py   5
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/svm.py            11
-rw-r--r--  tensorflow/contrib/learn/python/learn/evaluable.py                   9
-rw-r--r--  tensorflow/contrib/learn/python/learn/experiment.py                 62
-rw-r--r--  tensorflow/contrib/learn/python/learn/experiment_test.py            61
6 files changed, 139 insertions(+), 40 deletions(-)
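In user-facing terms, the commit adds a `checkpoint_path` argument to `Evaluable.evaluate` and an `evaluate_checkpoint_only_once` flag to `Experiment.continuous_eval`. A rough usage sketch follows, assuming the TensorFlow 0.12-era `tf.contrib.learn` API; the toy model, input function, and paths are illustrative only, not part of the commit:

```python
import tensorflow as tf

def input_fn():
    # Toy constant input pipeline; a real pipeline would read data.
    return {'x': tf.constant([[1.0]])}, tf.constant([[2.0]])

def model_fn(features, targets, mode):
    # Toy linear model in the (predictions, loss, train_op) style of the era.
    w = tf.get_variable('w', initializer=[[0.5]])
    predictions = tf.matmul(features['x'], w)
    loss = tf.reduce_mean(tf.square(predictions - targets))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.contrib.framework.get_global_step())
    return predictions, loss, train_op

est = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir='/tmp/demo')
est.fit(input_fn=input_fn, steps=10)

# New in this commit: evaluate can be pinned to a specific checkpoint path;
# passing None (the default) still evaluates the latest checkpoint.
est.evaluate(input_fn=input_fn, steps=1, checkpoint_path=None)

# New in this commit: continuous_eval skips checkpoints it has already
# evaluated unless evaluate_checkpoint_only_once=False. Note this loops
# forever by design.
ex = tf.contrib.learn.Experiment(est, train_input_fn=input_fn,
                                 eval_input_fn=input_fn)
ex.continuous_eval(evaluate_checkpoint_only_once=True)
```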
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index ea43fba6b3..bf1af5390e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -414,7 +414,7 @@ class BaseEstimator(
)
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
- steps=None, metrics=None, name=None):
+ steps=None, metrics=None, name=None, checkpoint_path=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Evaluable`.
@@ -429,11 +429,13 @@ class BaseEstimator(
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
- eval_results, global_step = self._evaluate_model(input_fn=input_fn,
- feed_fn=feed_fn,
- steps=steps,
- metrics=metrics,
- name=name)
+ eval_results, global_step = self._evaluate_model(
+ input_fn=input_fn,
+ feed_fn=feed_fn,
+ steps=steps,
+ metrics=metrics,
+ name=name,
+ checkpoint_path=checkpoint_path)
if eval_results is not None:
eval_results.update({'global_step': global_step})
return eval_results
@@ -769,18 +771,21 @@ class BaseEstimator(
steps,
feed_fn=None,
metrics=None,
- name=''):
+ name='',
+ checkpoint_path=None):
# TODO(wicke): Remove this once Model and associated code are gone.
if (hasattr(self._config, 'execution_mode') and
self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
return None, None
- # Check that model has been trained.
- checkpoint_path = self._model_dir
- latest_path = saver.latest_checkpoint(checkpoint_path)
- if not latest_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % checkpoint_path)
+ # Check that model has been trained (if nothing has been set explicitly).
+ if not checkpoint_path:
+ latest_path = saver.latest_checkpoint(self._model_dir)
+ if not latest_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % self._model_dir)
+ checkpoint_path = self._model_dir
+
# Setup output directory.
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
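The new control flow in `_evaluate_model` can be read in isolation: an explicit `checkpoint_path` bypasses the latest-checkpoint lookup entirely, while the default path still fails when training never produced a checkpoint. A standalone paraphrase of that check, with the function name and a plain `RuntimeError` (standing in for `NotFittedError`) as illustrative choices:

```python
import tensorflow as tf

def resolve_eval_checkpoint(model_dir, checkpoint_path=None):
    # Paraphrase of the check added to BaseEstimator._evaluate_model.
    if not checkpoint_path:
        # No explicit checkpoint requested: fall back to the latest one
        # in model_dir, and fail loudly if none exists yet.
        latest_path = tf.train.latest_checkpoint(model_dir)
        if not latest_path:
            raise RuntimeError(  # stands in for NotFittedError
                "Couldn't find trained model at %s." % model_dir)
        checkpoint_path = model_dir
    return checkpoint_path
```

Note that on the default path the commit passes the directory itself downstream rather than the resolved checkpoint file, matching the pre-existing behavior where the evaluation code resolves the latest checkpoint from a directory.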
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
index deb55efc9f..cf9b508470 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -219,6 +219,11 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
config=config,
feature_engineering_fn=feature_engineering_fn)
+ @property
+ def model_dir(self):
+ """See evaluable.Evaluable."""
+ return self._estimator.model_dir
+
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
steps=None, metrics=None, name=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index a6e4e7b6a3..561a898e78 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -168,6 +168,11 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
if not self._estimator.config.is_chief:
self._chief_hook = None
+ @property
+ def model_dir(self):
+ """See trainable.Evaluable."""
+ return self._estimator.model_dir
+
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
"""See trainable.Trainable."""
@@ -181,11 +186,13 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
# pylint: disable=protected-access
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
- batch_size=None, steps=None, metrics=None, name=None):
+ batch_size=None, steps=None, metrics=None, name=None,
+ checkpoint_path=None):
"""See evaluable.Evaluable."""
return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
feed_fn=feed_fn, batch_size=batch_size,
- steps=steps, metrics=metrics, name=name)
+ steps=steps, metrics=metrics, name=name,
+ checkpoint_path=checkpoint_path)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py
index 051d132416..14cf5f01b8 100644
--- a/tensorflow/contrib/learn/python/learn/evaluable.py
+++ b/tensorflow/contrib/learn/python/learn/evaluable.py
@@ -27,10 +27,15 @@ class Evaluable(object):
"""
__metaclass__ = abc.ABCMeta
+ @abc.abstractproperty
+ def model_dir(self):
+ """Returns a path in which the eval process will look for checkpoints."""
+ raise NotImplementedError
+
@abc.abstractmethod
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
- steps=None, metrics=None, name=None):
+ steps=None, metrics=None, name=None, checkpoint_path=None):
"""Evaluates given model with provided evaluation data.
Stop conditions - we evaluate on the given input data until one of the
@@ -81,6 +86,8 @@ class Evaluable(object):
`../../../metrics/python/ops/metrics_ops.py`.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data.
+ checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
+ latest checkpoint in `model_dir` is used.
Returns:
Returns `dict` with evaluation results.
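Since `model_dir` is now an `abc.abstractproperty`, any class claiming to be `Evaluable` must expose it; this is what forces the `model_dir` forwarding added to `SVM` and `TensorForestEstimator` above. A minimal conforming implementation, trimmed to the two abstract members and using the same delegation pattern (all names below are illustrative):

```python
import abc

class Evaluable(object):
    """Trimmed-down copy of the interface, for illustration."""
    __metaclass__ = abc.ABCMeta

    @abc.abstractproperty
    def model_dir(self):
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self, input_fn=None, steps=None, metrics=None, name=None,
                 checkpoint_path=None):
        raise NotImplementedError

class WrappingEstimator(Evaluable):
    # Delegation pattern used by SVM and TensorForestEstimator: wrap an
    # inner estimator and forward both model_dir and evaluate to it.
    def __init__(self, inner):
        self._inner = inner

    @property
    def model_dir(self):
        return self._inner.model_dir

    def evaluate(self, input_fn=None, steps=None, metrics=None, name=None,
                 checkpoint_path=None):
        return self._inner.evaluate(input_fn=input_fn, steps=steps,
                                    metrics=metrics, name=name,
                                    checkpoint_path=checkpoint_path)
```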
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 2615fa8a00..bf0791af51 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -29,9 +29,9 @@ from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config
-from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
@@ -219,7 +219,8 @@ class Experiment(object):
input_fn,
name,
delay_secs,
- throttle_delay_secs):
+ throttle_delay_secs,
+ evaluate_checkpoint_only_once=True):
"""Run continuous eval.
Runs infinite eval on the evaluation data set. This function starts
@@ -235,6 +236,8 @@ class Experiment(object):
throttle_delay_secs: Do not re-evaluate unless the last evaluation was
started at least this many seconds ago. If None, defaults to
self._continuous_eval_throttle_secs.
+ evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
+ that have already been evaluated. Default is `True`.
"""
if delay_secs is None:
delay_secs = self._eval_delay_secs
@@ -245,21 +248,33 @@ class Experiment(object):
logging.info("Waiting %f secs before starting eval.", delay_secs)
time.sleep(delay_secs)
- last_fitted_error_time = 0
+ previous_path = None
+ last_warning_time = 0
while True:
start = time.time()
- try:
+
+ error_msg = None
+ latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+ if not latest_path:
+ error_msg = ("Estimator is not fitted yet. "
+ "Will start an evaluation when a checkpoint is ready.")
+ elif evaluate_checkpoint_only_once and latest_path == previous_path:
+ error_msg = "No new checkpoint ready for evaluation."
+
+ if error_msg:
+ # Print warning message every 10 mins.
+ if time.time() - last_warning_time > 600:
+ logging.warning(error_msg)
+ last_warning_time = time.time()
+ else:
self._estimator.evaluate(input_fn=input_fn,
steps=self._eval_steps,
metrics=self._eval_metrics,
- name=name)
- except NotFittedError:
- # Print warning message every 10 mins.
- if time.time() - last_fitted_error_time > 600:
- logging.warning(
- "Estimator is not fitted yet. "
- "Will start an evaluation when a checkpoint will be ready.")
- last_fitted_error_time = time.time()
+ name=name,
+ checkpoint_path=latest_path)
+ # Clear warning timer and update last evaluated checkpoint
+ last_warning_time = 0
+ previous_path = latest_path
duration = time.time() - start
if duration < throttle_delay_secs:
@@ -268,11 +283,14 @@ class Experiment(object):
difference)
time.sleep(difference)
- def continuous_eval(self, delay_secs=None, throttle_delay_secs=None):
- self._continuous_eval(self._eval_input_fn,
- name="continuous",
- delay_secs=delay_secs,
- throttle_delay_secs=throttle_delay_secs)
+ def continuous_eval(self, delay_secs=None, throttle_delay_secs=None,
+ evaluate_checkpoint_only_once=True):
+ self._continuous_eval(
+ self._eval_input_fn,
+ name="continuous",
+ delay_secs=delay_secs,
+ throttle_delay_secs=throttle_delay_secs,
+ evaluate_checkpoint_only_once=evaluate_checkpoint_only_once)
def continuous_eval_on_train_data(self,
delay_secs=None,
@@ -328,7 +346,6 @@ class Experiment(object):
metrics=self._eval_metrics,
name=eval_dir_suffix)
-
def run_std_server(self):
"""Starts a TensorFlow server and joins the serving thread.
@@ -380,7 +397,14 @@ def _new_attr_context(obj, attr):
This creates a context in which an object's attribute can be changed.
Once the context is exited, the attribute reverts to its original value.
- Example usage:
+ Args:
+ obj: An object whose attribute to restore at the end of the context.
+ attr: An attribute to remember and restore at the end of the context.
+
+ Yields:
+ Context.
+
+ Example:
my_obj.x = 1
with _new_attr_context(my_obj, "x"):
my_obj.x = 2
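The heart of the change is the rewritten loop: instead of calling `evaluate` blindly and catching `NotFittedError`, it inspects the latest checkpoint first, skips it if it was already evaluated, and rate-limits the warning to one per ten minutes. A condensed, framework-free sketch of that loop; the `latest_checkpoint_fn`/`evaluate_fn` hooks and the fixed sleep are illustrative stand-ins for the saver call, the estimator, and the throttle logic:

```python
import time

def continuous_eval_loop(latest_checkpoint_fn, evaluate_fn,
                         evaluate_checkpoint_only_once=True,
                         warn_interval_secs=600):
    previous_path = None
    last_warning_time = 0
    while True:
        error_msg = None
        latest_path = latest_checkpoint_fn()
        if not latest_path:
            error_msg = ('Estimator is not fitted yet. Will start an '
                         'evaluation when a checkpoint is ready.')
        elif evaluate_checkpoint_only_once and latest_path == previous_path:
            error_msg = 'No new checkpoint ready for evaluation.'

        if error_msg:
            # Warn at most once per warn_interval_secs (10 min in the commit).
            if time.time() - last_warning_time > warn_interval_secs:
                print(error_msg)
                last_warning_time = time.time()
        else:
            evaluate_fn(checkpoint_path=latest_path)
            # Reset the warning timer and remember what was just evaluated.
            last_warning_time = 0
            previous_path = latest_path
        time.sleep(1)  # stands in for the throttle_delay_secs handling
```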
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d0f04ee842..0872f5a263 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -17,6 +17,9 @@ from __future__ import division
from __future__ import print_function
import json
+import os
+import tempfile
+import threading
import time
import tensorflow as tf
@@ -29,11 +32,17 @@ patch = tf.test.mock.patch
class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable):
- def __init__(self, config=None):
+ def __init__(self, config=None, max_evals=5):
self.eval_count = 0
self.fit_count = 0
+ self._max_evals = max_evals
self.monitors = []
self._config = config or run_config.RunConfig()
+ self._model_dir = tempfile.mkdtemp()
+
+ @property
+ def model_dir(self):
+ return self._model_dir
@property
def config(self):
@@ -42,12 +51,21 @@ class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable):
def evaluate(self, **kwargs):
tf.logging.info('evaluate called with args: %s' % kwargs)
self.eval_count += 1
- if self.eval_count > 5:
- tf.logging.info('Ran 6 evals. Done.')
+ if self.eval_count > self._max_evals:
+ tf.logging.info('Ran %d evals. Done.' % self.eval_count)
raise StopIteration()
return [(key, kwargs[key]) for key in sorted(kwargs.keys())]
+ def fake_checkpoint(self):
+ save_path = os.path.join(self.model_dir, 'model.ckpt')
+ with tf.Session() as sess:
+ var = tf.Variable(1.0, name='var0')
+ save = tf.train.Saver({var.op.name: var})
+ var.initializer.run()
+ save.save(sess, save_path, global_step=0)
+
def fit(self, **kwargs):
+ self.fake_checkpoint()
tf.logging.info('fit called with args: %s' % kwargs)
self.fit_count += 1
if 'monitors' in kwargs:
@@ -195,6 +213,7 @@ class ExperimentTest(tf.test.TestCase):
def test_evaluate(self):
est = TestEstimator()
+ est.fake_checkpoint()
ex = tf.contrib.learn.Experiment(
est,
train_input_fn='train_input',
@@ -208,6 +227,7 @@ class ExperimentTest(tf.test.TestCase):
def test_evaluate_delay(self):
est = TestEstimator()
+ est.fake_checkpoint()
ex = tf.contrib.learn.Experiment(
est, train_input_fn='train_input', eval_input_fn='eval_input')
@@ -220,6 +240,7 @@ class ExperimentTest(tf.test.TestCase):
def test_continuous_eval(self):
est = TestEstimator()
+ est.fake_checkpoint()
ex = tf.contrib.learn.Experiment(
est,
train_input_fn='train_input',
@@ -227,13 +248,15 @@ class ExperimentTest(tf.test.TestCase):
eval_metrics='eval_metrics',
eval_delay_secs=0,
continuous_eval_throttle_secs=0)
- self.assertRaises(StopIteration, ex.continuous_eval)
+ self.assertRaises(StopIteration, ex.continuous_eval,
+ evaluate_checkpoint_only_once=False)
self.assertEquals(6, est.eval_count)
self.assertEquals(0, est.fit_count)
def test_continuous_eval_throttle_delay(self):
for delay in [0, 1, 2]:
est = TestEstimator()
+ est.fake_checkpoint()
ex = tf.contrib.learn.Experiment(
est,
train_input_fn='train_input',
@@ -242,7 +265,8 @@ class ExperimentTest(tf.test.TestCase):
continuous_eval_throttle_secs=delay,
eval_delay_secs=0)
start = time.time()
- self.assertRaises(StopIteration, ex.continuous_eval)
+ self.assertRaises(StopIteration, ex.continuous_eval,
+ evaluate_checkpoint_only_once=False)
duration = time.time() - start
expected = 5 * delay
tf.logging.info('eval duration (expected %f): %f', expected, duration)
@@ -327,6 +351,33 @@ class ExperimentTest(tf.test.TestCase):
self.assertEquals(1, est.fit_count)
self.assertEquals(1, est.eval_count)
+ def test_continuous_eval_evaluates_checkpoint_once(self):
+ # The TestEstimator will raise StopIteration the second time evaluate is
+ # called.
+ ex = tf.contrib.learn.Experiment(
+ TestEstimator(max_evals=1),
+ train_input_fn='train_input',
+ eval_input_fn='eval_input')
+
+ # This should not happen if the logic restricting evaluation of the same
+ # checkpoint works. We do need some checkpoint though, otherwise Experiment
+ # will never evaluate.
+ ex.estimator.fake_checkpoint()
+
+ # Start a separate thread with continuous eval
+ thread = threading.Thread(
+ target=lambda: ex.continuous_eval(delay_secs=0, throttle_delay_secs=0))
+ thread.start()
+
+ # The thread will die if it evaluates twice, and we should never evaluate
+ # twice since we don't write another checkpoint. Since we did not enable
+ # throttling, if it hasn't died after two seconds, we're good.
+ thread.join(2)
+ self.assertTrue(thread.is_alive())
+
+ # But we should have evaluated once.
+ count = ex.estimator.eval_count
+ self.assertEquals(1, count)
if __name__ == '__main__':
tf.test.main()
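The new test's trick is worth noting: it runs `continuous_eval` on a background thread, waits two seconds, and asserts the thread is still alive, which proves the loop never re-evaluated the lone checkpoint (a second evaluation would raise `StopIteration` and kill the thread). A framework-free sketch of the same liveness-check pattern, with all names illustrative:

```python
import threading
import time

def loop_once_per_checkpoint(evals):
    # Stand-in for continuous_eval: one fixed checkpoint, dedupe enabled.
    previous = None
    while True:
        latest = 'model.ckpt-0'      # no new checkpoints ever appear
        if latest != previous:
            evals.append(latest)     # "evaluate" only unseen checkpoints
            previous = latest
        time.sleep(0.01)

evals = []
thread = threading.Thread(target=loop_once_per_checkpoint, args=(evals,))
thread.daemon = True
thread.start()
thread.join(2)                # give the loop plenty of iterations
assert thread.is_alive()      # the loop kept running without failing...
assert len(evals) == 1        # ...and evaluated the checkpoint exactly once
print('evaluated %d time(s)' % len(evals))
```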