path: root/tensorflow/python/estimator/training.py
author    Mustafa Ispir <ispir@google.com>  2018-06-18 17:04:18 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-18 17:07:19 -0700
commit    3edb609926f2521c726737fc1efeae1572dc6581 (patch)
tree      a4e72f1068d518a7823e9f11efede52c297d7fbf /tensorflow/python/estimator/training.py
parent    f91b5b0896e3ed2b57a32b5a21068b9b5c55899e (diff)
Improving local run behavior in estimator.train_and_evaluate.
The current behavior is unintuitive (it depends on throttle_secs) and leads to more frequent checkpointing than desired. This CL synchronizes evaluation with checkpointing. It also brings the local behavior closer to the distributed setting in the following ways:
* In the distributed setting the training input pipeline is created only once, whereas the current local behavior recreates the input pipeline on every loop iteration. This CL creates the training input pipeline only once.
* In the distributed setting the evaluator job waits for checkpoints dumped by the training job, whereas in the current local behavior the evaluator controls the checkpoint schedule. This CL gives that control back to the trainer.

PiperOrigin-RevId: 201085814
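To illustrate the new local behavior, here is a hedged, self-contained sketch (TF 1.x Estimator API; the toy model_fn and input_fn are only there to make it runnable, not part of the CL): checkpoint frequency set in RunConfig, rather than throttle_secs alone, now drives how often local evaluation runs.

    import numpy as np
    import tensorflow as tf

    # Toy model_fn / input_fn so the sketch is runnable; the point is the
    # RunConfig checkpoint settings, not the model.
    def model_fn(features, labels, mode):
      preds = tf.layers.dense(features['x'], 1)
      loss = tf.losses.mean_squared_error(labels, preds)
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=tf.train.get_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    def input_fn():
      x = np.random.rand(64, 4).astype(np.float32)
      y = np.sum(x, axis=1, keepdims=True)
      return tf.data.Dataset.from_tensor_slices(({'x': x}, y)).repeat().batch(8)

    # Checkpoint frequency now also determines local evaluation frequency.
    config = tf.estimator.RunConfig(save_checkpoints_steps=100)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
    train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=500)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, throttle_secs=0)

    # With this CL, the local run attaches a checkpoint-saver listener to a
    # single estimator.train() call, so the training input pipeline is built
    # once and evaluation runs after each saved checkpoint (subject to
    # eval_spec.throttle_secs).
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)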
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--  tensorflow/python/estimator/training.py | 160
1 file changed, 77 insertions(+), 83 deletions(-)
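Before the diff itself, a minimal sketch of the CheckpointSaverListener mechanism that the new _NewCheckpointListenerForEvaluate class builds on (this is not the CL's code; the class and callback names below are illustrative): the estimator invokes after_save each time a checkpoint is written, which is where evaluation is now triggered.

    import tensorflow as tf

    class EvalAfterCheckpointListener(tf.train.CheckpointSaverListener):
      """Illustrative listener: run a callback after every saved checkpoint."""

      def __init__(self, evaluate_fn):
        self._evaluate_fn = evaluate_fn  # callable taking the global step

      def after_save(self, session, global_step_value):
        del session  # unused; required by the listener signature
        self._evaluate_fn(global_step_value)
        # Returning True here would ask the training loop to stop early,
        # which is how the CL honors _ContinuousEvalListener requests.

    # Attached via the same hook point the CL uses:
    #   estimator.train(input_fn, saving_listeners=[EvalAfterCheckpointListener(fn)])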
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 1572af579b..37b123217a 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -470,6 +470,61 @@ class _StopAtSecsHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+class _NewCheckpointListenerForEvaluate(
+ basic_session_run_hooks.CheckpointSaverListener):
+ """A saver listener to run evaluate with every checkpoint."""
+
+ def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener):
+ self._evaluator = evaluator
+ self._eval_throttle_secs = eval_throttle_secs
+ self._continuous_eval_listener = continuous_eval_listener
+ self.eval_result, self.export_results = None, None
+
+ def begin(self):
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=self._eval_throttle_secs)
+ self._is_first_run = True
+
+ def after_save(self, session, global_step_value):
+ del session # unused; required by signature.
+ # Skip the first run; the model is not trained yet.
+ if self._is_first_run:
+ self._is_first_run = False
+ return
+
+ if not self._continuous_eval_listener.before_eval():
+ logging.info('Exiting training and evaluation loop, as requested by '
+ '_ContinuousEvalListener.before_eval.')
+ return True
+ if self._timer.should_trigger_for_step(global_step_value):
+ self._evaluate(global_step_value) # updates self.eval_result
+ if not self._continuous_eval_listener.after_eval(self.eval_result):
+ logging.info('Exiting evaluation, as requested by '
+ '_ContinuousEvalListener.after_eval.')
+ return True
+ else:
+ # TODO(ispir): add remaining time in the log.
+ logging.info('Skip the current checkpoint eval due to throttle secs '
+ '({} secs).'.format(self._eval_throttle_secs))
+
+ def end(self, session, global_step_value):
+ # Evaluate if the last step has not been evaluated yet.
+ if global_step_value != self._timer.last_triggered_step():
+ if self._continuous_eval_listener.before_eval():
+ self._evaluate(global_step_value)
+ self._continuous_eval_listener.after_eval(self.eval_result)
+
+ def _evaluate(self, global_step_value):
+ self._timer.update_last_triggered_step(global_step_value)
+ self.eval_result, self.export_results = (
+ self._evaluator.evaluate_and_export())
+ if self.eval_result.status != _EvalStatus.EVALUATED:
+ # This is unexpected; should never happen.
+ # Training should always end with a new checkpoint.
+ raise RuntimeError('There was no new checkpoint after the training. '
+ 'Eval status: {}'.format(self.eval_result.status))
+
+
class _TrainingExecutor(object):
"""The executor to run `Estimator` training and evaluation.
@@ -576,28 +631,6 @@ class _TrainingExecutor(object):
def run_master(self):
"""Runs task master."""
-
- class NewCheckpointListener(
- basic_session_run_hooks.CheckpointSaverListener):
-
- def __init__(self, evaluator, eval_throttle_secs):
- self._evaluator = evaluator
- self._eval_throttle_secs = eval_throttle_secs
-
- def begin(self):
- self._timer = basic_session_run_hooks.SecondOrStepTimer(
- every_secs=self._eval_throttle_secs)
-
- def after_save(self, session, global_step_value):
- del session # unused; required by signature.
-
- if self._timer.should_trigger_for_step(global_step_value):
- self._timer.update_last_triggered_step(global_step_value)
- self._evaluator.evaluate_and_export()
- else:
- logging.info('Skip the current checkpoint eval due to throttle secs '
- '({} secs).'.format(self._eval_throttle_secs))
-
_assert_eval_spec(self._eval_spec)
# Final export signal: For any eval result with global_step >= train
@@ -617,16 +650,12 @@ class _TrainingExecutor(object):
# When the underlying `Estimator` object saves a new checkpoint, we would
# like this callback to be called so that evaluation and export can trigger.
saving_listeners = [
- NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
+ _NewCheckpointListenerForEvaluate(evaluator,
+ self._eval_spec.throttle_secs,
+ _ContinuousEvalListener())
]
self._start_distributed_training(saving_listeners=saving_listeners)
- if not evaluator.is_final_export_triggered:
- logging.info('Training has already ended. But the last eval is skipped '
- 'due to eval throttle_secs. Now evaluating the final '
- 'checkpoint.')
- evaluator.evaluate_and_export()
-
def run_evaluator(self):
"""Runs task evaluator."""
# TODO(xiejw): To allow execution framework to add continuous eval listener.
@@ -640,68 +669,33 @@ class _TrainingExecutor(object):
def run_local(self):
"""Runs training and evaluation locally (non-distributed)."""
-
- def _should_stop_local_train(global_step):
- if self._train_spec.max_steps is None:
- return False
- if global_step >= self._train_spec.max_steps:
- return True
- return False
-
_assert_eval_spec(self._eval_spec)
- if self._eval_spec.throttle_secs <= 0:
- raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
- 'It is used do determine how long each training '
- 'iteration should go when train and evaluate '
- 'locally.'.format(self._eval_spec.throttle_secs))
-
- stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
- train_hooks = (
- list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
+ train_hooks = list(self._train_spec.hooks) + list(self._train_hooks)
logging.info('Start train and evaluate loop. The evaluate will happen '
- 'after {} secs (eval_spec.throttle_secs) or training is '
- 'finished.'.format(self._eval_spec.throttle_secs))
+ 'after every checkpoint. Checkpoint frequency is determined '
+ 'based on RunConfig arguments: save_checkpoints_steps {} or '
+ 'save_checkpoints_secs {}.'.format(
+ self._estimator.config.save_checkpoints_steps,
+ self._estimator.config.save_checkpoints_secs))
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
self._train_spec.max_steps)
- eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
- export_results = []
-
- while True:
- self._estimator.train(
- input_fn=self._train_spec.input_fn,
- max_steps=self._train_spec.max_steps,
- hooks=train_hooks)
-
- if not self._continuous_eval_listener.before_eval():
- logging.info('Exiting training and evaluation loop, as requested by '
- '_ContinuousEvalListener.before_eval.')
- break
-
- # Final export signal: For any eval result with global_step >= train
- # max_steps, the evaluator will send the final export signal. The
- # _should_stop_local_train will then end the while True as the stopping
- # condition is satisfied (both checks use the same global_step value,
- # i.e., no race condition)
- eval_result, export_results = evaluator.evaluate_and_export()
-
- if eval_result.status != _EvalStatus.EVALUATED:
- # This is unexpected; should never happen.
- # Training should always end with a new checkpoint.
- raise RuntimeError('There was no new checkpoint after the training. '
- 'Eval status: {}'.format(eval_result.status))
-
- if not self._continuous_eval_listener.after_eval(eval_result):
- logging.info('Exiting evaluation, as requested by '
- '_ContinuousEvalListener.after_eval.')
- break
+ listener_for_eval = _NewCheckpointListenerForEvaluate(
+ evaluator, self._eval_spec.throttle_secs,
+ self._continuous_eval_listener)
+ saving_listeners = [listener_for_eval]
+
+ self._estimator.train(
+ input_fn=self._train_spec.input_fn,
+ max_steps=self._train_spec.max_steps,
+ hooks=train_hooks,
+ saving_listeners=saving_listeners)
- if _should_stop_local_train(
- eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
- break
- return eval_result.metrics, export_results
+ eval_result = listener_for_eval.eval_result or _EvalResult(
+ status=_EvalStatus.MISSING_CHECKPOINT)
+ return eval_result.metrics, listener_for_eval.export_results
def _start_std_server(self, config):
"""Creates, starts, and returns a server_lib.Server."""