author    | 2016-12-01 23:10:01 -0800
committer | 2016-12-01 23:23:10 -0800
commit    | 40d4b8efa317411113784135612adc960f2721aa (patch)
tree      | f962d65ccf26850b524cdb9168592fffbfc765f4
parent    | c78b1ec74d1c4aab88bca6e9010c7582c1f89706 (diff)
Don't evaluate checkpoints twice in continuous eval.
Change: 140819261
6 files changed, 139 insertions, 40 deletions
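In short: `Evaluable.evaluate` gains a `checkpoint_path` argument, the estimators gain a `model_dir` property, and `Experiment._continuous_eval` stops relying on `NotFittedError`, instead remembering the last checkpoint it evaluated and skipping it on the next pass. As orientation before the per-file diffs, here is a minimal sketch of the new loop's control flow, assuming a fitted `tf.contrib.learn` estimator; the function name and parameters are illustrative, not part of the commit:

```python
import time

from tensorflow.python.training import saver


def continuous_eval_sketch(estimator, input_fn, throttle_secs=60):
  """Illustrative sketch of the de-duplicating eval loop this commit adds."""
  previous_path = None   # last checkpoint we evaluated
  last_warning_time = 0
  while True:
    start = time.time()
    latest_path = saver.latest_checkpoint(estimator.model_dir)
    if not latest_path or latest_path == previous_path:
      # No checkpoint yet, or nothing new: warn at most every 10 minutes
      # instead of re-evaluating the same checkpoint.
      if time.time() - last_warning_time > 600:
        print('No new checkpoint ready for evaluation.')
        last_warning_time = time.time()
    else:
      # Uses the checkpoint_path argument the commit adds to evaluate().
      estimator.evaluate(input_fn=input_fn, checkpoint_path=latest_path)
      previous_path = latest_path
      last_warning_time = 0
    duration = time.time() - start
    if duration < throttle_secs:
      time.sleep(throttle_secs - duration)
```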
```diff
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index ea43fba6b3..bf1af5390e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -414,7 +414,7 @@ class BaseEstimator(
   )
   def evaluate(
       self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
-      steps=None, metrics=None, name=None):
+      steps=None, metrics=None, name=None, checkpoint_path=None):
     # pylint: disable=g-doc-args,g-doc-return-or-yield
     """See `Evaluable`.
 
@@ -429,11 +429,13 @@ class BaseEstimator(
     if metrics is not None and not isinstance(metrics, dict):
       raise ValueError('Metrics argument should be None or dict. '
                        'Got %s.' % metrics)
-    eval_results, global_step = self._evaluate_model(input_fn=input_fn,
-                                                     feed_fn=feed_fn,
-                                                     steps=steps,
-                                                     metrics=metrics,
-                                                     name=name)
+    eval_results, global_step = self._evaluate_model(
+        input_fn=input_fn,
+        feed_fn=feed_fn,
+        steps=steps,
+        metrics=metrics,
+        name=name,
+        checkpoint_path=checkpoint_path)
     if eval_results is not None:
       eval_results.update({'global_step': global_step})
     return eval_results
@@ -769,18 +771,21 @@ class BaseEstimator(
                       steps,
                       feed_fn=None,
                       metrics=None,
-                      name=''):
+                      name='',
+                      checkpoint_path=None):
     # TODO(wicke): Remove this once Model and associated code are gone.
     if (hasattr(self._config, 'execution_mode') and
         self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
       return None, None
 
-    # Check that model has been trained.
-    checkpoint_path = self._model_dir
-    latest_path = saver.latest_checkpoint(checkpoint_path)
-    if not latest_path:
-      raise NotFittedError("Couldn't find trained model at %s."
-                           % checkpoint_path)
+    # Check that model has been trained (if nothing has been set explicitly).
+    if not checkpoint_path:
+      latest_path = saver.latest_checkpoint(self._model_dir)
+      if not latest_path:
+        raise NotFittedError("Couldn't find trained model at %s."
+                             % self._model_dir)
+      checkpoint_path = self._model_dir
+
     # Setup output directory.
     eval_dir = os.path.join(self._model_dir, 'eval' if not name else
                             'eval_' + name)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
index deb55efc9f..cf9b508470 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -219,6 +219,11 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
         config=config,
         feature_engineering_fn=feature_engineering_fn)
 
+  @property
+  def model_dir(self):
+    """See evaluable.Evaluable."""
+    return self._estimator.model_dir
+
   def evaluate(
       self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
       steps=None, metrics=None, name=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index a6e4e7b6a3..561a898e78 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -168,6 +168,11 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
     if not self._estimator.config.is_chief:
       self._chief_hook = None
 
+  @property
+  def model_dir(self):
+    """See trainable.Evaluable."""
+    return self._estimator.model_dir
+
   def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
           monitors=None, max_steps=None):
     """See trainable.Trainable."""
@@ -181,11 +186,13 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
 
   # pylint: disable=protected-access
   def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
-               batch_size=None, steps=None, metrics=None, name=None):
+               batch_size=None, steps=None, metrics=None, name=None,
+               checkpoint_path=None):
     """See evaluable.Evaluable."""
     return self._estimator.evaluate(x=x, y=y, input_fn=input_fn,
                                     feed_fn=feed_fn, batch_size=batch_size,
-                                    steps=steps, metrics=metrics, name=name)
+                                    steps=steps, metrics=metrics, name=name,
+                                    checkpoint_path=checkpoint_path)
 
   @deprecated_arg_values(
       estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py
index 051d132416..14cf5f01b8 100644
--- a/tensorflow/contrib/learn/python/learn/evaluable.py
+++ b/tensorflow/contrib/learn/python/learn/evaluable.py
@@ -27,10 +27,15 @@ class Evaluable(object):
   """
   __metaclass__ = abc.ABCMeta
 
+  @abc.abstractproperty
+  def model_dir(self):
+    """Returns a path in which the eval process will look for checkpoints."""
+    raise NotImplementedError
+
   @abc.abstractmethod
   def evaluate(
       self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
-      steps=None, metrics=None, name=None):
+      steps=None, metrics=None, name=None, checkpoint_path=None):
     """Evaluates given model with provided evaluation data.
 
     Stop conditions - we evaluate on the given input data until one of the
@@ -81,6 +86,8 @@ class Evaluable(object):
         `../../../metrics/python/ops/metrics_ops.py`.
       name: Name of the evaluation if user needs to run multiple evaluations on
         different data sets, such as on training data vs test data.
+      checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
+        latest checkpoint in `model_dir` is used.
 
     Returns:
       Returns `dict` with evaluation results.
```
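With the interface change above, callers can pin an evaluation to one specific checkpoint instead of whatever is newest in `model_dir`. A hedged usage sketch; `est` and `eval_input_fn` are caller-supplied placeholders, not part of the commit:

```python
from tensorflow.python.training import saver


def evaluate_pinned(est, eval_input_fn):
  """Sketch: est is any fitted tf.contrib.learn Evaluable (assumption)."""
  # Default behavior, unchanged: evaluate the newest checkpoint in model_dir.
  latest_metrics = est.evaluate(input_fn=eval_input_fn, steps=1)

  # New in this commit: evaluate an explicit checkpoint path instead.
  path = saver.latest_checkpoint(est.model_dir)
  pinned_metrics = est.evaluate(input_fn=eval_input_fn, steps=1,
                                checkpoint_path=path)
  return latest_metrics, pinned_metrics
```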
```diff
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 2615fa8a00..bf0791af51 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -29,9 +29,9 @@
 from tensorflow.contrib.learn.python.learn import evaluable
 from tensorflow.contrib.learn.python.learn import monitors
 from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import run_config
-from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver
 from tensorflow.python.training import server_lib
@@ -219,7 +219,8 @@ class Experiment(object):
                        input_fn,
                        name,
                        delay_secs,
-                       throttle_delay_secs):
+                       throttle_delay_secs,
+                       evaluate_checkpoint_only_once=True):
     """Run continuous eval.
 
     Runs infinite eval on the evaluation data set. This function starts
@@ -235,6 +236,8 @@ class Experiment(object):
       throttle_delay_secs: Do not re-evaluate unless the last evaluation was
         started at least this many seconds ago. If None, defaults to
         self._continuous_eval_throttle_secs.
+      evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
+        that have already been evaluated. Default is `True`.
     """
     if delay_secs is None:
       delay_secs = self._eval_delay_secs
@@ -245,21 +248,33 @@ class Experiment(object):
       logging.info("Waiting %f secs before starting eval.", delay_secs)
      time.sleep(delay_secs)
 
-    last_fitted_error_time = 0
+    previous_path = None
+    last_warning_time = 0
     while True:
       start = time.time()
-      try:
+
+      error_msg = None
+      latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+      if not latest_path:
+        error_msg = ("Estimator is not fitted yet. "
+                     "Will start an evaluation when a checkpoint is ready.")
+      elif evaluate_checkpoint_only_once and latest_path == previous_path:
+        error_msg = "No new checkpoint ready for evaluation."
+
+      if error_msg:
+        # Print warning message every 10 mins.
+        if time.time() - last_warning_time > 600:
+          logging.warning(error_msg)
+          last_warning_time = time.time()
+      else:
         self._estimator.evaluate(input_fn=input_fn,
                                  steps=self._eval_steps,
                                  metrics=self._eval_metrics,
-                                 name=name)
-      except NotFittedError:
-        # Print warning message every 10 mins.
-        if time.time() - last_fitted_error_time > 600:
-          logging.warning(
-              "Estimator is not fitted yet. "
-              "Will start an evaluation when a checkpoint will be ready.")
-          last_fitted_error_time = time.time()
+                                 name=name,
+                                 use_checkpoint=latest_path)
+        # Clear warning timer and update last evaluated checkpoint
+        last_warning_time = 0
+        previous_path = latest_path
 
       duration = time.time() - start
       if duration < throttle_delay_secs:
@@ -268,11 +283,14 @@ class Experiment(object):
                         difference)
           time.sleep(difference)
 
-  def continuous_eval(self, delay_secs=None, throttle_delay_secs=None):
-    self._continuous_eval(self._eval_input_fn,
-                          name="continuous",
-                          delay_secs=delay_secs,
-                          throttle_delay_secs=throttle_delay_secs)
+  def continuous_eval(self, delay_secs=None, throttle_delay_secs=None,
+                      evaluate_checkpoint_only_once=True):
+    self._continuous_eval(
+        self._eval_input_fn,
+        name="continuous",
+        delay_secs=delay_secs,
+        throttle_delay_secs=throttle_delay_secs,
+        evaluate_checkpoint_only_once=evaluate_checkpoint_only_once)
 
   def continuous_eval_on_train_data(self,
                                     delay_secs=None,
@@ -328,7 +346,6 @@ class Experiment(object):
         metrics=self._eval_metrics,
         name=eval_dir_suffix)
 
-
   def run_std_server(self):
     """Starts a TensorFlow server and joins the serving thread.
@@ -380,7 +397,14 @@ def _new_attr_context(obj, attr):
   This creates a context in which an object's attribute can be changed.
   Once the context is exited, the attribute reverts to its original value.
 
-  Example usage:
+  Args:
+    obj: An object whose attribute to restore at the end of the context.
+    attr: An attribute to remember and restore at the end of the context.
+
+  Yields:
+    Context.
+
+  Example:
     my_obj.x = 1
     with _new_attr_context(my_obj, "x"):
       my_obj.x = 2
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d0f04ee842..0872f5a263 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -17,6 +17,9 @@ from __future__ import division
 from __future__ import print_function
 
 import json
+import os
+import tempfile
+import threading
 import time
 
 import tensorflow as tf
@@ -29,11 +32,17 @@ patch = tf.test.mock.patch
 
 class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable):
 
-  def __init__(self, config=None):
+  def __init__(self, config=None, max_evals=5):
     self.eval_count = 0
     self.fit_count = 0
+    self._max_evals = max_evals
     self.monitors = []
     self._config = config or run_config.RunConfig()
+    self._model_dir = tempfile.mkdtemp()
+
+  @property
+  def model_dir(self):
+    return self._model_dir
 
   @property
   def config(self):
@@ -42,12 +51,21 @@ class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable):
   def evaluate(self, **kwargs):
     tf.logging.info('evaluate called with args: %s' % kwargs)
     self.eval_count += 1
-    if self.eval_count > 5:
-      tf.logging.info('Ran 6 evals. Done.')
+    if self.eval_count > self._max_evals:
+      tf.logging.info('Ran %d evals. Done.' % self.eval_count)
       raise StopIteration()
     return [(key, kwargs[key]) for key in sorted(kwargs.keys())]
 
+  def fake_checkpoint(self):
+    save_path = os.path.join(self.model_dir, 'model.ckpt')
+    with tf.Session() as sess:
+      var = tf.Variable(1.0, name='var0')
+      save = tf.train.Saver({var.op.name: var})
+      var.initializer.run()
+      save.save(sess, save_path, global_step=0)
+
   def fit(self, **kwargs):
+    self.fake_checkpoint()
     tf.logging.info('fit called with args: %s' % kwargs)
     self.fit_count += 1
     if 'monitors' in kwargs:
@@ -195,6 +213,7 @@ class ExperimentTest(tf.test.TestCase):
 
   def test_evaluate(self):
     est = TestEstimator()
+    est.fake_checkpoint()
     ex = tf.contrib.learn.Experiment(
         est,
         train_input_fn='train_input',
@@ -208,6 +227,7 @@ class ExperimentTest(tf.test.TestCase):
 
   def test_evaluate_delay(self):
     est = TestEstimator()
+    est.fake_checkpoint()
     ex = tf.contrib.learn.Experiment(
         est, train_input_fn='train_input', eval_input_fn='eval_input')
@@ -220,6 +240,7 @@ class ExperimentTest(tf.test.TestCase):
 
   def test_continuous_eval(self):
     est = TestEstimator()
+    est.fake_checkpoint()
     ex = tf.contrib.learn.Experiment(
         est,
         train_input_fn='train_input',
@@ -227,13 +248,15 @@ class ExperimentTest(tf.test.TestCase):
         eval_metrics='eval_metrics',
         eval_delay_secs=0,
         continuous_eval_throttle_secs=0)
-    self.assertRaises(StopIteration, ex.continuous_eval)
+    self.assertRaises(StopIteration, ex.continuous_eval,
+                      evaluate_checkpoint_only_once=False)
     self.assertEquals(6, est.eval_count)
     self.assertEquals(0, est.fit_count)
 
   def test_continuous_eval_throttle_delay(self):
     for delay in [0, 1, 2]:
       est = TestEstimator()
+      est.fake_checkpoint()
       ex = tf.contrib.learn.Experiment(
           est,
           train_input_fn='train_input',
@@ -242,7 +265,8 @@ class ExperimentTest(tf.test.TestCase):
           continuous_eval_throttle_secs=delay,
           eval_delay_secs=0)
       start = time.time()
-      self.assertRaises(StopIteration, ex.continuous_eval)
+      self.assertRaises(StopIteration, ex.continuous_eval,
+                        evaluate_checkpoint_only_once=False)
       duration = time.time() - start
       expected = 5 * delay
       tf.logging.info('eval duration (expected %f): %f', expected, duration)
@@ -327,6 +351,33 @@ class ExperimentTest(tf.test.TestCase):
     self.assertEquals(1, est.fit_count)
     self.assertEquals(1, est.eval_count)
 
+  def test_continuous_eval_evaluates_checkpoint_once(self):
+    # The TestEstimator will raise StopIteration the second time evaluate is
+    # called.
+    ex = tf.contrib.learn.Experiment(
+        TestEstimator(max_evals=1),
+        train_input_fn='train_input',
+        eval_input_fn='eval_input')
+
+    # This should not happen if the logic restricting evaluation of the same
+    # checkpoint works. We do need some checkpoint though, otherwise Experiment
+    # will never evaluate.
+    ex.estimator.fake_checkpoint()
+
+    # Start a separate thread with continuous eval
+    thread = threading.Thread(
+        target=lambda: ex.continuous_eval(delay_secs=0, throttle_delay_secs=0))
+    thread.start()
+
+    # The thread will die if it evaluates twice, and we should never evaluate
+    # twice since we don't write another checkpoint. Since we did not enable
+    # throttling, if it hasn't died after two seconds, we're good.
+    thread.join(2)
+    self.assertTrue(thread.is_alive())
+
+    # But we should have evaluated once.
+    count = ex.estimator.eval_count
+    self.assertEquals(1, count)
 
 if __name__ == '__main__':
   tf.test.main()
```
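For completeness, here is how the new flag might be driven from user code; the estimator and input functions are placeholders, assumed to exist, and the constructor arguments mirror those used in the tests above:

```python
import tensorflow as tf


def run_continuous_eval(est, train_input, eval_input):
  """Sketch: est is a fitted tf.contrib.learn estimator (assumption)."""
  ex = tf.contrib.learn.Experiment(
      est,
      train_input_fn=train_input,
      eval_input_fn=eval_input,
      eval_delay_secs=0,
      continuous_eval_throttle_secs=60)
  # Default (True): each checkpoint is evaluated at most once; pass False
  # to restore the old behavior of re-evaluating the latest checkpoint on
  # every pass.
  ex.continuous_eval(evaluate_checkpoint_only_once=True)
```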