author | Jianwei Xie <xiejw@google.com> | 2017-10-05 12:58:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 13:07:40 -0700 |
commit | 631d3434ff33debfd0bf46d9d8602172f549c82d (patch) | |
tree | e03196bd1b8e35d5fc4e85bacde43dc3b215f7c0 /tensorflow/python/estimator/training.py | |
parent | a429d07bf545b5fd25c44f95fd50e012440bf99b (diff) |
Adds throttle_secs into run_master
PiperOrigin-RevId: 171194766
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r-- | tensorflow/python/estimator/training.py | 74 |
1 file changed, 59 insertions, 15 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 5c0ebbea35..64b014a6b5 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -519,23 +519,51 @@ class _TrainingExecutor(object):
 
     class NewCheckpointListener(
         basic_session_run_hooks.CheckpointSaverListener):
 
-      def __init__(self, estimator, eval_spec, max_training_steps):
-        # pylint: disable=protected-access
-        self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec,
-                                                        max_training_steps)
-        # pylint: enable=protected-access
+      def __init__(self, evaluator, eval_throttle_secs):
+        self._evaluator = evaluator
+        self._eval_throttle_secs = eval_throttle_secs
+
+      def begin(self):
+        self._timer = basic_session_run_hooks.SecondOrStepTimer(
+            every_secs=self._eval_throttle_secs)
 
       def after_save(self, session, global_step_value):
-        del session, global_step_value
-        self._evaluator.evaluate_and_export()
+        del session  # unused; required by signature.
+
+        if self._timer.should_trigger_for_step(global_step_value):
+          self._timer.update_last_triggered_step(global_step_value)
+          self._evaluator.evaluate_and_export()
+        else:
+          logging.info(
+              'Skip the current checkpoint eval due to throttle secs '
+              '({} secs).'.format(self._eval_throttle_secs))
+
+    # Final export signal: For any eval result with global_step >= train
+    # max_steps, the evaluator will send the final export signal. There is a
+    # small chance that the Estimator.train stopping logic sees a different
+    # global_step value (due to global step race condition and the fact the
+    # saver sees a larger value for checkpoing saving), which does not end
+    # the training. When the training ends, a new checkpoint is generated, which
+    # triggers the listener again. So, it could be the case the final export is
+    # triggered twice.
+    #
+    # But here, throttle_secs will skip the next intermediate checkpoint and,
+    # so, the double final export chance is very small.
+    evaluator = _TrainingExecutor._Evaluator(
+        self._estimator, self._eval_spec, self._train_spec.max_steps)
 
     # When the underlying `Estimator` object saves a new checkpoint, we would
     # like this callback to be called so that evaluation and export can trigger.
     saving_listeners = [
-        NewCheckpointListener(self._estimator, self._eval_spec,
-                              self._train_spec.max_steps)
+        NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
     ]
-    return self._start_distributed_training(saving_listeners=saving_listeners)
+    self._start_distributed_training(saving_listeners=saving_listeners)
+
+    if not evaluator.is_final_export_triggered:
+      logging.info('Training has already ended. But the last eval is skipped '
+                   'due to eval throttle_secs. Now evaluating the final '
+                   'checkpoint.')
+      evaluator.evaluate_and_export()
 
   def run_evaluator(self):
     """Runs task evaluator."""
@@ -580,6 +608,11 @@ class _TrainingExecutor(object):
           max_steps=self._train_spec.max_steps,
           hooks=train_hooks)
 
+      # Final export signal: For any eval result with global_step >= train
+      # max_steps, the evaluator will send the final export signal. The
+      # _should_stop_local_train will then end the while True as the stopping
+      # condition is satisfied (both checks use the same global_step value,
+      # i.e., no race condition)
       metrics = evaluator.evaluate_and_export()
 
       if not metrics:
@@ -656,6 +689,11 @@ class _TrainingExecutor(object):
             self._train_spec.max_steps)
         return
 
+      # Final export signal: For any eval result with global_step >= train
+      # max_steps, the evaluator will send the final export signal. The next
+      # iteration of while loop will end the continuous eval as the stopping
+      # condition is satisfied (both checks use the same global_step value,
+      # i.e., no race condition)
       start = time.time()
       latest_eval_result = evaluator.evaluate_and_export()
 
@@ -673,10 +711,15 @@ class _TrainingExecutor(object):
 
     def __init__(self, estimator, eval_spec, max_training_steps):
       self._estimator = estimator
      self._eval_spec = eval_spec
+      self._is_final_export_triggered = False
      self._previous_ckpt_path = None
      self._last_warning_time = 0
      self._max_training_steps = max_training_steps
 
+    @property
+    def is_final_export_triggered(self):
+      return self._is_final_export_triggered
+
     def evaluate_and_export(self):
       """Evaluate and (maybe) export the current model.
@@ -720,15 +763,16 @@ class _TrainingExecutor(object):
             'Internal error: `Estimator.evaluate` result should have '
             '`global_step` in result. Given {}'.format(eval_result))
 
-      # TODO(isaprykin): There is a potential race condition here in the
-      # distributed setting. The worker job that performs training
-      # might stop at a later global step value than the evalutor job.
       is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
                              self._max_training_steps
                              if self._max_training_steps else False)
       self._export_eval_result(eval_result, latest_ckpt_path,
                                is_the_final_export)
 
+      if is_the_final_export:
+        logging.debug('Calling exporter with the `is_the_final_export=True`.')
+        self._is_final_export_triggered = True
+
       self._last_warning_time = 0
       self._previous_ckpt_path = latest_ckpt_path
       return eval_result
@@ -749,8 +793,8 @@ class _TrainingExecutor(object):
 
       for exporter in self._eval_spec.exporters:
         exporter.export(
-            self._estimator,
-            os.path.join(
+            estimator=self._estimator,
+            export_path=os.path.join(
                 compat.as_str_any(export_dir_base),
                 compat.as_str_any(exporter.name)),
             checkpoint_path=checkpoint_path,
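For readers who want the core of the change in isolation: the new listener is a standard throttling pattern built on tf.train.SecondOrStepTimer. The sketch below restates that pattern against the public tf.train listener API rather than the module-internal _TrainingExecutor._Evaluator; the class name ThrottledEvalListener and the `evaluator` argument (any object exposing evaluate_and_export(), as the internal evaluator does) are illustrative stand-ins, not part of this commit.

import tensorflow as tf


class ThrottledEvalListener(tf.train.CheckpointSaverListener):
  """Evaluates after a checkpoint save, at most once per throttle interval."""

  def __init__(self, evaluator, throttle_secs):
    self._evaluator = evaluator
    self._throttle_secs = throttle_secs
    self._timer = None

  def begin(self):
    # The timer fires only when at least `throttle_secs` have elapsed since
    # the last triggered step, which is how the commit rate-limits
    # per-checkpoint evaluation.
    self._timer = tf.train.SecondOrStepTimer(every_secs=self._throttle_secs)

  def after_save(self, session, global_step_value):
    del session  # Unused; required by the CheckpointSaverListener signature.
    if self._timer.should_trigger_for_step(global_step_value):
      self._timer.update_last_triggered_step(global_step_value)
      self._evaluator.evaluate_and_export()
    # Otherwise the checkpoint is skipped, mirroring the logging.info branch
    # in the diff above.

Note that a freshly constructed SecondOrStepTimer triggers on its first should_trigger_for_step call, so the first checkpoint after begin() is always evaluated; only subsequent checkpoints inside the throttle window are skipped.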
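From the user's side, the knob surfaces as EvalSpec.throttle_secs. The end-to-end configuration below is hypothetical and assumes the public tf.estimator.train_and_evaluate API that this module backs (TF 1.4-era); the canned LinearRegressor and the numpy_input_fn inputs exist only to keep the example self-contained.

import numpy as np
import tensorflow as tf

x_train = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
y_train = np.array([[2.0], [4.0], [6.0], [8.0]], dtype=np.float32)

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': x_train}, y=y_train, batch_size=2, num_epochs=None, shuffle=True)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': x_train}, y=y_train, batch_size=2, num_epochs=1, shuffle=False)

estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')])

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=200)

# throttle_secs is the value run_master now hands to the checkpoint listener:
# checkpoints saved less than 60 seconds after the previous evaluation are
# skipped rather than evaluated.
eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn, steps=2, throttle_secs=60)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

If training finishes while the most recent checkpoint is still inside the throttle window, run_master logs that the last eval was skipped and evaluates that final checkpoint anyway, so the final export is not lost.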