author    | Jianwei Xie <xiejw@google.com>                  | 2017-12-16 18:57:31 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-18 11:00:34 -0800
commit    | c11e07925a2c40ee220b9a3d76f82dc6ef17b87a (patch)
tree      | fd15ccc5eea1f60690a6cfddbbaab062e31d2e53 /tensorflow/python/estimator/training.py
parent    | d234325e2b82174e203cbdb8f19dfb86bbad7bec (diff)
Introduce the ContinuousEvalListener
PiperOrigin-RevId: 179319836
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r-- | tensorflow/python/estimator/training.py | 238
1 file changed, 189 insertions, 49 deletions
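This commit changes `_Evaluator.evaluate_and_export` to return a validated `_EvalResult` named tuple instead of a bare metrics dict (or `None`), and threads an optional `_ContinuousEvalListener` through `_TrainingExecutor` so continuous evaluation can be observed or stopped from outside. As a rough illustration of the new result contract (a sketch only, exercising the module-private classes this commit adds; the metric values and checkpoint path below are made up):

```python
from tensorflow.python.estimator import training
from tensorflow.python.framework import ops

# A skipped round carries only a status. Passing metrics or checkpoint_path
# together with a non-EVALUATED status makes _EvalResult.__new__ raise
# ValueError.
skipped = training._EvalResult(status=training._EvalStatus.NO_NEW_CHECKPOINT)

# An EVALUATED result must carry a metrics dict that includes global_step,
# plus the checkpoint the metrics were computed from; __new__ validates both.
evaluated = training._EvalResult(
    status=training._EvalStatus.EVALUATED,
    metrics={ops.GraphKeys.GLOBAL_STEP: 1000, 'loss': 0.25},  # made-up values
    checkpoint_path='/tmp/model_dir/model.ckpt-1000')  # made-up path

assert evaluated.metrics[ops.GraphKeys.GLOBAL_STEP] == 1000
```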
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 58fccc3a29..569ea04f01 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -488,7 +488,11 @@ class _TrainingExecutor(object):
   training and evaluation based on the setting in `tf.estimator.RunConfig`.
   """
 
-  def __init__(self, estimator, train_spec, eval_spec):
+  def __init__(self,
+               estimator,
+               train_spec,
+               eval_spec,
+               continuous_eval_listener=None):
     if not isinstance(estimator, estimator_lib.Estimator):
       raise TypeError('`estimator` must have type `tf.estimator.Estimator`.')
     self._estimator = estimator
@@ -501,6 +505,13 @@ class _TrainingExecutor(object):
       raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.')
     self._eval_spec = eval_spec
 
+    if (continuous_eval_listener and
+        not isinstance(continuous_eval_listener, _ContinuousEvalListener)):
+      raise TypeError('`continuous_eval_listener` must have type '
+                      '`_ContinuousEvalListener`.')
+    self._continuous_eval_listener = (
+        continuous_eval_listener or _ContinuousEvalListener())
+
   @property
   def estimator(self):
     return self._estimator
@@ -615,13 +626,16 @@ class _TrainingExecutor(object):
       # _should_stop_local_train will then end the while True as the stopping
       # condition is satisfied (both checks use the same global_step value,
      # i.e., no race condition)
-      metrics = evaluator.evaluate_and_export()
+      eval_result = evaluator.evaluate_and_export()
 
-      if not metrics:
-        # This is unexpected. Training should always end with a new checkpoint.
-        raise RuntimeError('There was no new checkpoint after the training.')
+      if eval_result.status != _EvalStatus.EVALUATED:
+        # This is unexpected; should never happen.
+        # Training should always end with a new checkpoint.
+        raise RuntimeError('There was no new checkpoint after the training. '
+                           'Eval status: {}'.format(eval_result.status))
 
-      if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]):
+      if _should_stop_local_train(
+          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
         break
 
   def _start_std_server(self, config):
@@ -697,9 +711,11 @@ class _TrainingExecutor(object):
     evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                              self._train_spec.max_steps)
 
-    while True:
-      if latest_eval_result:
-        global_step = latest_eval_result.get(ops.GraphKeys.GLOBAL_STEP)
+    should_early_stop = False
+    while not should_early_stop:
+      if (latest_eval_result and
+          latest_eval_result.status == _EvalStatus.EVALUATED):
+        global_step = latest_eval_result.metrics.get(ops.GraphKeys.GLOBAL_STEP)
         if (global_step and self._train_spec.max_steps and
             global_step >= self._train_spec.max_steps):
           logging.info(
@@ -708,21 +724,46 @@ class _TrainingExecutor(object):
              self._train_spec.max_steps)
           return
 
-      # Final export signal: For any eval result with global_step >= train
-      # max_steps, the evaluator will send the final export signal. The next
-      # iteration of while loop will end the continuous eval as the stopping
-      # condition is satisfied (both checks use the same global_step value,
-      # i.e., no race condition)
-      start = time.time()
-      latest_eval_result = evaluator.evaluate_and_export()
+      latest_eval_result, should_early_stop = self._execute_evaluator_once(
+          evaluator, self._continuous_eval_listener,
+          self._eval_spec.throttle_secs)
+
+  def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
+                              throttle_secs):
+    """Executes the `evaluator`."""
+    start = time.time()
+
+    eval_result = None
+    should_early_stop = False
 
-      # Throttle if necessary.
-      elapsed_time = time.time() - start
-      difference = self._eval_spec.throttle_secs - elapsed_time
-      if difference > 0:
-        logging.info('Waiting %f secs before starting next eval run.',
-                     difference)
-        time.sleep(difference)
+    if not continuous_eval_listener.before_eval():
+      logging.info('Exiting evaluation, as requested by '
+                   '_ContinuousEvalListener.before_eval.')
+      should_early_stop = True
+      return (eval_result, should_early_stop)
+
+    # Final export signal: For any eval result with global_step >= train
+    # max_steps, the evaluator will send the final export signal. The next
+    # iteration of while loop will end the continuous eval as the stopping
+    # condition is satisfied (both checks use the same global_step value,
+    # i.e., no race condition)
+    eval_result = evaluator.evaluate_and_export()
+
+    if not self._continuous_eval_listener.after_eval(eval_result):
+      logging.info('Exiting evaluation, as requested by '
+                   '_ContinuousEvalListener.after_eval.')
+      should_early_stop = True
+      return (eval_result, should_early_stop)
+
+    # Throttle if necessary.
+    elapsed_time = time.time() - start
+    difference = throttle_secs - elapsed_time
+    if difference > 0:
+      logging.info('Waiting %f secs before starting next eval run.',
+                   difference)
+      time.sleep(difference)
+
+    return (eval_result, should_early_stop)
 
   class _Evaluator(object):
     """A helper class to call `Estimator.evaluate` and export model."""
@@ -743,8 +784,7 @@ class _TrainingExecutor(object):
       """Evaluate and (maybe) export the current model.
 
       Returns:
-        Evaluation results. Returns `None` if current round of evaluation is
-        skipped.
+        An `EvalResult` instance.
 
       Raises:
         RuntimeError: for any unexpected internal error.
@@ -754,39 +794,32 @@ class _TrainingExecutor(object):
      if not latest_ckpt_path:
         self._log_err_msg('Estimator is not trained yet. Will start an '
                           'evaluation when a checkpoint is ready.')
-        return None
+        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
 
       if latest_ckpt_path == self._previous_ckpt_path:
         self._log_err_msg(
             'No new checkpoint ready for evaluation. Skip the current '
             'evaluation pass as evaluation results are expected to be same '
             'for the same checkpoint.')
-        return None
-      eval_result = self._estimator.evaluate(
+        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)
+
+      metrics = self._estimator.evaluate(
           input_fn=self._eval_spec.input_fn,
           steps=self._eval_spec.steps,
           name=self._eval_spec.name,
          checkpoint_path=latest_ckpt_path,
           hooks=self._eval_spec.hooks)
 
-      if not eval_result:
-        raise RuntimeError(
-            'Internal error: `Estimator.evaluate` should never return empty '
-            'result.')
-      if not isinstance(eval_result, dict):
-        raise TypeError(
-            '`Estimator.evaluate` should return dict. Given {}.'.format(
-                type(eval_result)))
-      if ops.GraphKeys.GLOBAL_STEP not in eval_result:
-        raise RuntimeError(
-            'Internal error: `Estimator.evaluate` result should have '
-            '`global_step` in result. Given {}'.format(eval_result))
+      # _EvalResult validates the metrics.
+      eval_result = _EvalResult(
+          status=_EvalStatus.EVALUATED,
+          metrics=metrics,
+          checkpoint_path=latest_ckpt_path)
 
-      is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
-                             self._max_training_steps
-                             if self._max_training_steps else False)
-      self._export_eval_result(eval_result, latest_ckpt_path,
-                               is_the_final_export)
+      is_the_final_export = (
+          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
+          self._max_training_steps if self._max_training_steps else False)
+      self._export_eval_result(eval_result, is_the_final_export)
 
       if is_the_final_export:
         logging.debug('Calling exporter with the `is_the_final_export=True`.')
@@ -803,8 +836,7 @@ class _TrainingExecutor(object):
         logging.warning(message)
         self._last_warning_time = current_time
 
-    def _export_eval_result(self, eval_result, checkpoint_path,
-                            is_the_final_export):
+    def _export_eval_result(self, eval_result, is_the_final_export):
       """Export `eval_result` according to exporters in `EvalSpec`."""
       export_dir_base = os.path.join(
           compat.as_str_any(self._estimator.model_dir),
@@ -816,6 +848,114 @@ class _TrainingExecutor(object):
             export_path=os.path.join(
                 compat.as_str_any(export_dir_base),
                 compat.as_str_any(exporter.name)),
-            checkpoint_path=checkpoint_path,
-            eval_result=eval_result,
+            checkpoint_path=eval_result.checkpoint_path,
+            eval_result=eval_result.metrics,
             is_the_final_export=is_the_final_export)
+
+
+class _EvalStatus(object):
+  """The status of an evaluation event.
+
+  For local training and evaluation, the status can only be `EVALUATED` as
+  `Estimator.train` always generates a new checkpoint.
+
+  For distributed training and evaluation, a separated evaluator keeps looking
+  for new checkpoint. So, multiple situations might occur:
+
+  - EVALUATED: A new checkpoint is found since last evaluation.
+      `Estimator.evaluate` will be invoked.
+  - MISSING_CHECKPOINT: No checkpoint can be found. Typically, this means
+      the trainer has not yet produced any checkpoint.
+  - NO_NEW_CHECKPOINT: No new checkpoint can be found since last evaluation.
+      Typically, this means the trainer has not yet produced any new
+      checkpoint.
+  """
+
+  EVALUATED = 'evaluated'
+  MISSING_CHECKPOINT = 'missing checkpoint'
+  NO_NEW_CHECKPOINT = 'no new checkpoint'
+
+
+class _EvalResult(
+    collections.namedtuple('EvalResult',
+                           ['status', 'metrics', 'checkpoint_path'])):
+  """_EvalResult holds the result of an evaluation event."""
+
+  def __new__(cls, status, metrics=None, checkpoint_path=None):
+    """Creates a validated `_EvalResult`.
+
+    Args:
+      status: See `_EvalStatus`.
+      metrics: The evaluation results returned by `Estimator.evaluate`. Only
+        set if status is `EVALUATED`.
+      checkpoint_path: The corresponding checkpoint path for the `metrics`.
+        Only set if status is `EVALUATED`.
+
+    Returns:
+      A validated `_EvalResult` object.
+
+    Raises:
+      ValueError: If validation fails.
+      TypeError: If any of the arguments is not the expected type.
+    """
+
+    if status != _EvalStatus.EVALUATED:
+      if metrics:
+        raise ValueError(
+            'metrics must be `None` if status is not {}; got status {},'
+            ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))
+      if checkpoint_path:
+        raise ValueError(
+            'checkpoint must be `None` if status is not {}; got status {}, '
+            'checkpoint_path {}'.format(
+                _EvalStatus.EVALUATED, status, checkpoint_path))
+      return super(_EvalResult, cls).__new__(cls, status, metrics,
+                                             checkpoint_path)
+
+    # Now, evaluated case.
+    assert status == _EvalStatus.EVALUATED
+
+    # Validates metrics.
+    if not metrics:
+      raise ValueError(
+          'Internal error: `Estimator.evaluate` should never return empty '
+          'metrics.')
+    if not isinstance(metrics, dict):
+      raise TypeError(
+          '`Estimator.evaluate` should return dict. Given {}.'.format(
+              type(metrics)))
+    if ops.GraphKeys.GLOBAL_STEP not in metrics:
+      raise ValueError(
+          'Internal error: `Estimator.evaluate` result should have '
+          '`global_step` in result. Given {}'.format(metrics))
+
+    # Validates checkpoint_path.
+    if not checkpoint_path:
+      raise ValueError(
+          'Internal error: `checkpoint_path` should never be empty.')
+
+    return super(_EvalResult, cls).__new__(cls, status, metrics,
+                                           checkpoint_path)
+
+
+class _ContinuousEvalListener(object):
+  """Interface for listeners that take action before or after evaluation."""
+
+  def before_eval(self):
+    """Called before evaluation.
+
+    Returns:
+      `False` if you want to skip the current evaluation and early stop the
+      continuous evaluation; `True` otherwise.
+    """
+    return True
+
+  def after_eval(self, eval_result):
+    """Called after the evaluation is executed.
+
+    Args:
+      eval_result: An `_EvalResult` instance.
+
+    Returns:
+      `False` if you want to early stop continuous evaluation; `True`
+      otherwise.
+    """
+    del eval_result
+    return True
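The listener hooks give external code two exit points from the continuous-evaluation loop: `before_eval` returning `False` skips the round and stops the loop, while `after_eval` returning `False` stops the loop after the round completes. Below is a minimal sketch of a custom listener wired into the executor, assuming `my_estimator`, `my_train_spec`, and `my_eval_spec` already exist; `_StopAfterNEvalsListener` is a hypothetical example, not part of this commit:

```python
from tensorflow.python.estimator import training


class _StopAfterNEvalsListener(training._ContinuousEvalListener):
  """Hypothetical listener: stops after `n` completed evaluation rounds."""

  def __init__(self, n):
    self._remaining = n

  def before_eval(self):
    # Returning False makes _execute_evaluator_once skip this round and
    # report should_early_stop=True, which ends the while loop.
    return self._remaining > 0

  def after_eval(self, eval_result):
    # Only count rounds in which a checkpoint was actually evaluated;
    # MISSING_CHECKPOINT / NO_NEW_CHECKPOINT rounds do not decrement.
    if eval_result.status == training._EvalStatus.EVALUATED:
      self._remaining -= 1
    return True


executor = training._TrainingExecutor(
    estimator=my_estimator,
    train_spec=my_train_spec,
    eval_spec=my_eval_spec,
    continuous_eval_listener=_StopAfterNEvalsListener(n=5))
executor.run_evaluator()  # Runs continuous evaluation with the listener.
```

Because `__init__` falls back to the no-op base `_ContinuousEvalListener` when none is supplied, `_execute_evaluator_once` can invoke `before_eval` and `after_eval` unconditionally.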