author    Jianwei Xie <xiejw@google.com>  2017-12-16 18:57:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-18 11:00:34 -0800
commit    c11e07925a2c40ee220b9a3d76f82dc6ef17b87a (patch)
tree      fd15ccc5eea1f60690a6cfddbbaab062e31d2e53 /tensorflow/python/estimator/training.py
parent    d234325e2b82174e203cbdb8f19dfb86bbad7bec (diff)
Introduce the ContinuousEvalListener
PiperOrigin-RevId: 179319836
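
The new hook surface is small: a `_ContinuousEvalListener` exposes `before_eval()` and `after_eval(eval_result)`, and returning `False` from either one ends the continuous-evaluation loop early. As a minimal sketch of the extension point, using the classes introduced in the diff below (the `_StopAfterNEvalsListener` name and its stop-after-N policy are illustrative only, not part of this commit):

    class _StopAfterNEvalsListener(_ContinuousEvalListener):
      """Illustrative listener that ends continuous evaluation after `n` rounds."""

      def __init__(self, n):
        self._remaining = n

      def before_eval(self):
        # Returning False skips the pending evaluation and ends the loop early.
        return self._remaining > 0

      def after_eval(self, eval_result):
        # `eval_result` is an _EvalResult; only count rounds that actually ran.
        if eval_result.status == _EvalStatus.EVALUATED:
          self._remaining -= 1
        # Returning False here would also end the loop after the current round.
        return True
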
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--  tensorflow/python/estimator/training.py | 238
1 file changed, 189 insertions(+), 49 deletions(-)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 58fccc3a29..569ea04f01 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -488,7 +488,11 @@ class _TrainingExecutor(object):
training and evaluation based on the setting in `tf.estimator.RunConfig`.
"""
- def __init__(self, estimator, train_spec, eval_spec):
+ def __init__(self,
+ estimator,
+ train_spec,
+ eval_spec,
+ continuous_eval_listener=None):
if not isinstance(estimator, estimator_lib.Estimator):
raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
self._estimator = estimator
@@ -501,6 +505,13 @@ class _TrainingExecutor(object):
raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
self._eval_spec = eval_spec
+ if (continuous_eval_listener and
+ not isinstance(continuous_eval_listener, _ContinuousEvalListener)):
+ raise TypeError('`continuous_eval_listener` must have type '
+ '`_ContinuousEvalListener`.')
+ self._continuous_eval_listener = (
+ continuous_eval_listener or _ContinuousEvalListener())
+
@property
def estimator(self):
return self._estimator
@@ -615,13 +626,16 @@ class _TrainingExecutor(object):
# _should_stop_local_train will then end the while True as the stopping
# condition is satisfied (both checks use the same global_step value,
# i.e., no race condition)
- metrics = evaluator.evaluate_and_export()
+ eval_result = evaluator.evaluate_and_export()
- if not metrics:
- # This is unexpected. Training should always end with a new checkpoint.
- raise RuntimeError('There was no new checkpoint after the training.')
+ if eval_result.status != _EvalStatus.EVALUATED:
+ # This is unexpected; should never happen.
+ # Training should always end with a new checkpoint.
+ raise RuntimeError('There was no new checkpoint after the training. '
+ 'Eval status: {}'.format(eval_result.status))
- if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
+ if _should_stop_local_train(
+ eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
break
def _start_std_server(self, config):
@@ -697,9 +711,11 @@ class _TrainingExecutor(object):
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
self._train_spec.max_steps)
- while True:
- if latest_eval_result:
- global_step = latest_eval_result.get(ops.GraphKeys.GLOBAL_STEP)
+ should_early_stop = False
+ while not should_early_stop:
+ if (latest_eval_result and
+ latest_eval_result.status == _EvalStatus.EVALUATED):
+ global_step = latest_eval_result.metrics.get(ops.GraphKeys.GLOBAL_STEP)
if (global_step and self._train_spec.max_steps and
global_step >= self._train_spec.max_steps):
logging.info(
@@ -708,21 +724,46 @@ class _TrainingExecutor(object):
self._train_spec.max_steps)
return
- # Final export signal: For any eval result with global_step >= train
- # max_steps, the evaluator will send the final export signal. The next
- # iteration of while loop will end the continuous eval as the stopping
- # condition is satisfied (both checks use the same global_step value,
- # i.e., no race condition)
- start = time.time()
- latest_eval_result = evaluator.evaluate_and_export()
+ latest_eval_result, should_early_stop = self._execute_evaluator_once(
+ evaluator, self._continuous_eval_listener,
+ self._eval_spec.throttle_secs)
+
+ def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
+ throttle_secs):
+ """Executes the `evaluator`."""
+ start = time.time()
+
+ eval_result = None
+ should_early_stop = False
- # Throttle if necessary.
- elapsed_time = time.time() - start
- difference = self._eval_spec.throttle_secs - elapsed_time
- if difference > 0:
- logging.info('Waiting %f secs before starting next eval run.',
- difference)
- time.sleep(difference)
+ if not continuous_eval_listener.before_eval():
+ logging.info('Exiting evaluation, as requested by '
+ '_ContinuousEvalListener.before_eval.')
+ should_early_stop = True
+ return (eval_result, should_early_stop)
+
+ # Final export signal: For any eval result with global_step >= train
+ # max_steps, the evaluator will send the final export signal. The next
+ # iteration of while loop will end the continuous eval as the stopping
+ # condition is satisfied (both checks use the same global_step value,
+ # i.e., no race condition)
+ eval_result = evaluator.evaluate_and_export()
+
+ if not self._continuous_eval_listener.after_eval(eval_result):
+ logging.info('Exiting evaluation, as requested by '
+ '_ContinuousEvalListener.after_eval.')
+ should_early_stop = True
+ return (eval_result, should_early_stop)
+
+ # Throttle if necessary.
+ elapsed_time = time.time() - start
+ difference = throttle_secs - elapsed_time
+ if difference > 0:
+ logging.info('Waiting %f secs before starting next eval run.',
+ difference)
+ time.sleep(difference)
+
+ return (eval_result, should_early_stop)
class _Evaluator(object):
"""A helper class to call `Estimator.evaluate` and export model."""
@@ -743,8 +784,7 @@ class _TrainingExecutor(object):
"""Evaluate and (maybe) export the current model.
Returns:
- Evaluation results. Returns `None` if current round of evaluation is
- skipped.
+ An `EvalResult` instance.
Raises:
RuntimeError: for any unexpected internal error.
@@ -754,39 +794,32 @@ class _TrainingExecutor(object):
if not latest_ckpt_path:
self._log_err_msg('Estimator is not trained yet. Will start an '
'evaluation when a checkpoint is ready.')
- return None
+ return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
if latest_ckpt_path == self._previous_ckpt_path:
self._log_err_msg(
'No new checkpoint ready for evaluation. Skip the current '
'evaluation pass as evaluation results are expected to be same '
'for the same checkpoint.')
- return None
- eval_result = self._estimator.evaluate(
+ return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)
+
+ metrics = self._estimator.evaluate(
input_fn=self._eval_spec.input_fn,
steps=self._eval_spec.steps,
name=self._eval_spec.name,
checkpoint_path=latest_ckpt_path,
hooks=self._eval_spec.hooks)
- if not eval_result:
- raise RuntimeError(
- 'Internal error: `Estimator.evaluate` should never return empty '
- 'result.')
- if not isinstance(eval_result, dict):
- raise TypeError(
- '`Estimator.evaluate` should return dict. Given {}.'.format(
- type(eval_result)))
- if ops.GraphKeys.GLOBAL_STEP not in eval_result:
- raise RuntimeError(
- 'Internal error: `Estimator.evaluate` result should have '
- '`global_step` in result. Given {}'.format(eval_result))
+ # _EvalResult validates the metrics.
+ eval_result = _EvalResult(
+ status=_EvalStatus.EVALUATED,
+ metrics=metrics,
+ checkpoint_path=latest_ckpt_path)
- is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
- self._max_training_steps
- if self._max_training_steps else False)
- self._export_eval_result(eval_result, latest_ckpt_path,
- is_the_final_export)
+ is_the_final_export = (
+ eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
+ self._max_training_steps if self._max_training_steps else False)
+ self._export_eval_result(eval_result, is_the_final_export)
if is_the_final_export:
logging.debug('Calling exporter with the `is_the_final_export=True`.')
@@ -803,8 +836,7 @@ class _TrainingExecutor(object):
logging.warning(message)
self._last_warning_time = current_time
- def _export_eval_result(self, eval_result, checkpoint_path,
- is_the_final_export):
+ def _export_eval_result(self, eval_result, is_the_final_export):
"""Export `eval_result` according to exporters in `EvalSpec`."""
export_dir_base = os.path.join(
compat.as_str_any(self._estimator.model_dir),
@@ -816,6 +848,114 @@ class _TrainingExecutor(object):
export_path=os.path.join(
compat.as_str_any(export_dir_base),
compat.as_str_any(exporter.name)),
- checkpoint_path=checkpoint_path,
- eval_result=eval_result,
+ checkpoint_path=eval_result.checkpoint_path,
+ eval_result=eval_result.metrics,
is_the_final_export=is_the_final_export)
+
+
+class _EvalStatus(object):
+ """The status of an evaluation event.
+
+ For local training and evaluation, the status can only be `EVALUATED` as
+ `Estimator.train` always generates a new checkpoint.
+
+ For distributed training and evaluation, a separate evaluator keeps looking
+ for new checkpoints, so multiple situations can occur:
+
+ - EVALUATED: A new checkpoint has been found since the last evaluation.
+ `Estimator.evaluate` will be invoked.
+ - MISSING_CHECKPOINT: No checkpoint can be found. Typically, this means
+ the trainer has not yet produced any checkpoint.
+ - NO_NEW_CHECKPOINT: No new checkpoint has been found since the last
+ evaluation. Typically, this means the trainer has not yet produced a new
+ checkpoint.
+ """
+
+ EVALUATED = 'evaluated'
+ MISSING_CHECKPOINT = 'missing checkpoint'
+ NO_NEW_CHECKPOINT = 'no new checkpoint'
+
+
+class _EvalResult(
+ collections.namedtuple('EvalResult',
+ ['status', 'metrics', 'checkpoint_path'])):
+ """_EvalResult holds the result of an evaluation event."""
+
+ def __new__(cls, status, metrics=None, checkpoint_path=None):
+ """Creates a validated `_EvalResult`.
+
+ Args:
+ status: See `_EvalStatus`.
+ metrics: The evaluation results returned by `Estimator.evaluate`. Only set
+ if status is `EVALUATED`.
+ checkpoint_path: The corresponding checkpoint path for the `metrics`. Only
+ set if status is `EVALUATED`.
+ Returns:
+ A validated `_EvalResult` object.
+
+ Raises:
+ ValueError: If validation fails.
+ TypeError: If any of the arguments is not the expected type.
+ """
+
+ if status != _EvalStatus.EVALUATED:
+ if metrics:
+ raise ValueError(
+ 'metrics must be `None` if status is not {}; got status {},'
+ ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))
+ if checkpoint_path:
+ raise ValueError(
+ 'checkpoint must be `None` if status is not {}; got status {}, '
+ 'checkpoint_path {}'.format(
+ _EvalStatus.EVALUATED, status, checkpoint_path))
+ return super(_EvalResult, cls).__new__(cls, status, metrics,
+ checkpoint_path)
+
+ # Now, evaluated case.
+ assert status == _EvalStatus.EVALUATED
+
+ # Validates metrics.
+ if not metrics:
+ raise ValueError(
+ 'Internal error: `Estimator.evaluate` should never return empty '
+ 'metrics.')
+ if not isinstance(metrics, dict):
+ raise TypeError(
+ '`Estimator.evaluate` should return dict. Given {}.'.format(
+ type(metrics)))
+ if ops.GraphKeys.GLOBAL_STEP not in metrics:
+ raise ValueError(
+ 'Internal error: `Estimator.evaluate` result should have '
+ '`global_step` in result. Given {}'.format(metrics))
+
+ # Validates checkpoint_path.
+ if not checkpoint_path:
+ raise ValueError(
+ 'Internal error: `checkpoint_path` should never be empty.')
+
+ return super(_EvalResult, cls).__new__(cls, status, metrics,
+ checkpoint_path)
+
+
+class _ContinuousEvalListener(object):
+ """Interface for listeners that take action before or after evaluation."""
+
+ def before_eval(self):
+ """Called before evaluation.
+
+ Returns:
+ `False` to skip the current evaluation and stop the continuous
+ evaluation early; `True` otherwise.
+ """
+ return True
+
+ def after_eval(self, eval_result):
+ """Called after the evaluation is executed.
+
+ Args:
+ eval_result: An `_EvalResult` instance.
+
+ Returns:
+ `False` to stop the continuous evaluation early; `True` otherwise.
+ """
+ del eval_result
+ return True
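
For completeness, a usage sketch showing how a custom listener plugs into `_TrainingExecutor`. This is a hypothetical illustration, not part of the commit: `my_estimator`, `my_train_spec`, and `my_eval_spec` are placeholders, `_StopAfterNEvalsListener` is the sketch shown after the commit message above, and `run_evaluator()` is assumed to be the executor entry point defined elsewhere in this file that drives `_start_continuous_evaluation`.

    # Hypothetical wiring; every `my_*` name below is a placeholder.
    executor = _TrainingExecutor(
        estimator=my_estimator,    # a tf.estimator.Estimator
        train_spec=my_train_spec,  # a tf.estimator.TrainSpec
        eval_spec=my_eval_spec,    # a tf.estimator.EvalSpec
        continuous_eval_listener=_StopAfterNEvalsListener(n=3))

    # On the evaluator task, each round now calls before_eval(), evaluates if a
    # new checkpoint exists, calls after_eval(eval_result), and then throttles.
    executor.run_evaluator()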