path: root/tensorflow/python/estimator/training.py
author     Jianwei Xie <xiejw@google.com>  2017-10-05 12:58:51 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-10-05 13:07:40 -0700
commit     631d3434ff33debfd0bf46d9d8602172f549c82d (patch)
tree       e03196bd1b8e35d5fc4e85bacde43dc3b215f7c0 /tensorflow/python/estimator/training.py
parent     a429d07bf545b5fd25c44f95fd50e012440bf99b (diff)
Adds throttle_secs into run_master
PiperOrigin-RevId: 171194766
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--  tensorflow/python/estimator/training.py | 74
1 file changed, 59 insertions(+), 15 deletions(-)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 5c0ebbea35..64b014a6b5 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -519,23 +519,51 @@ class _TrainingExecutor(object):
     class NewCheckpointListener(
         basic_session_run_hooks.CheckpointSaverListener):
-      def __init__(self, estimator, eval_spec, max_training_steps):
-        # pylint: disable=protected-access
-        self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec,
-                                                       max_training_steps)
-        # pylint: enable=protected-access
+      def __init__(self, evaluator, eval_throttle_secs):
+        self._evaluator = evaluator
+        self._eval_throttle_secs = eval_throttle_secs
+
+      def begin(self):
+        self._timer = basic_session_run_hooks.SecondOrStepTimer(
+            every_secs=self._eval_throttle_secs)
       def after_save(self, session, global_step_value):
-        del session, global_step_value
-        self._evaluator.evaluate_and_export()
+        del session  # unused; required by signature.
+
+        if self._timer.should_trigger_for_step(global_step_value):
+          self._timer.update_last_triggered_step(global_step_value)
+          self._evaluator.evaluate_and_export()
+        else:
+          logging.info(
+              'Skip the current checkpoint eval due to throttle secs '
+              '({} secs).'.format(self._eval_throttle_secs))
+
+    # Final export signal: For any eval result with global_step >= train
+    # max_steps, the evaluator will send the final export signal. There is a
+    # small chance that the Estimator.train stopping logic sees a different
+    # global_step value (due to a global step race condition and the fact that
+    # the saver sees a larger value for checkpoint saving), which does not end
+    # the training. When the training ends, a new checkpoint is generated,
+    # which triggers the listener again. So, it could be the case that the
+    # final export is triggered twice.
+    #
+    # But here, throttle_secs will skip the next intermediate checkpoint, so
+    # the chance of a double final export is very small.
+    evaluator = _TrainingExecutor._Evaluator(
+        self._estimator, self._eval_spec, self._train_spec.max_steps)
     # When the underlying `Estimator` object saves a new checkpoint, we would
     # like this callback to be called so that evaluation and export can trigger.
     saving_listeners = [
-        NewCheckpointListener(self._estimator, self._eval_spec,
-                              self._train_spec.max_steps)
+        NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
     ]
-    return self._start_distributed_training(saving_listeners=saving_listeners)
+    self._start_distributed_training(saving_listeners=saving_listeners)
+
+    if not evaluator.is_final_export_triggered:
+      logging.info('Training has already ended. But the last eval is skipped '
+                   'due to eval throttle_secs. Now evaluating the final '
+                   'checkpoint.')
+      evaluator.evaluate_and_export()
   def run_evaluator(self):
     """Runs task evaluator."""
@@ -580,6 +608,11 @@ class _TrainingExecutor(object):
           max_steps=self._train_spec.max_steps,
           hooks=train_hooks)
+      # Final export signal: For any eval result with global_step >= train
+      # max_steps, the evaluator will send the final export signal. The
+      # _should_stop_local_train check will then end the `while True` loop, as
+      # the stopping condition is satisfied (both checks use the same
+      # global_step value, i.e., no race condition).
       metrics = evaluator.evaluate_and_export()
       if not metrics:
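The _should_stop_local_train helper named in the comment above is not part of this diff; a hypothetical sketch of the check it describes, assuming it simply compares the global step reported by the evaluation against TrainSpec.max_steps:

def _should_stop_local_train(global_step, max_steps):
  # Hypothetical helper, not taken from training.py. With no max_steps the
  # loop only stops when the training input ends; otherwise stop once the
  # evaluated global_step reaches max_steps. Because this reads the same
  # global_step value the evaluator just exported with, the stop decision and
  # the final export cannot disagree.
  if max_steps is None:
    return False
  return global_step >= max_steps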
@@ -656,6 +689,11 @@ class _TrainingExecutor(object):
               self._train_spec.max_steps)
           return
+      # Final export signal: For any eval result with global_step >= train
+      # max_steps, the evaluator will send the final export signal. The next
+      # iteration of the while loop will then end the continuous eval, as the
+      # stopping condition is satisfied (both checks use the same global_step
+      # value, i.e., no race condition).
       start = time.time()
       latest_eval_result = evaluator.evaluate_and_export()
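Only `start = time.time()` and the evaluate_and_export() call are visible in this hunk; a sketch of the throttled continuous-eval loop they imply, assuming the evaluator sleeps out the remainder of EvalSpec.throttle_secs between evaluations (the function and should_stop_fn names are placeholders):

import time


def _continuous_eval(evaluator, throttle_secs, should_stop_fn):
  # Placeholder loop: `evaluator` exposes evaluate_and_export() and
  # `should_stop_fn` wraps the global_step >= max_steps check shown above.
  while not should_stop_fn():
    start = time.time()
    evaluator.evaluate_and_export()
    # Sleep out whatever is left of the throttle window so evaluations are
    # spaced at least `throttle_secs` apart.
    elapsed_time = time.time() - start
    if elapsed_time < throttle_secs:
      time.sleep(throttle_secs - elapsed_time)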
@@ -673,10 +711,15 @@ class _TrainingExecutor(object):
     def __init__(self, estimator, eval_spec, max_training_steps):
       self._estimator = estimator
       self._eval_spec = eval_spec
+      self._is_final_export_triggered = False
       self._previous_ckpt_path = None
       self._last_warning_time = 0
       self._max_training_steps = max_training_steps
+    @property
+    def is_final_export_triggered(self):
+      return self._is_final_export_triggered
+
     def evaluate_and_export(self):
       """Evaluate and (maybe) export the current model.
@@ -720,15 +763,16 @@ class _TrainingExecutor(object):
             'Internal error: `Estimator.evaluate` result should have '
             '`global_step` in result. Given {}'.format(eval_result))
-      # TODO(isaprykin): There is a potential race condition here in the
-      # distributed setting. The worker job that performs training
-      # might stop at a later global step value than the evalutor job.
       is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
                              self._max_training_steps
                              if self._max_training_steps else False)
       self._export_eval_result(eval_result, latest_ckpt_path,
                                is_the_final_export)
+      if is_the_final_export:
+        logging.debug('Calling exporter with the `is_the_final_export=True`.')
+        self._is_final_export_triggered = True
+
       self._last_warning_time = 0
       self._previous_ckpt_path = latest_ckpt_path
       return eval_result
@@ -749,8 +793,8 @@ class _TrainingExecutor(object):
       for exporter in self._eval_spec.exporters:
         exporter.export(
-            self._estimator,
-            os.path.join(
+            estimator=self._estimator,
+            export_path=os.path.join(
                 compat.as_str_any(export_dir_base),
                 compat.as_str_any(exporter.name)),
             checkpoint_path=checkpoint_path,
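The switch to keyword arguments in the last hunk spells out the Exporter.export signature this call site expects. A minimal custom exporter compatible with it, assuming the tf.estimator.Exporter base class of TF 1.x (the class name and logging body are illustrative only):

import tensorflow as tf


class LoggingExporter(tf.estimator.Exporter):
  """Illustrative exporter that logs instead of writing a SavedModel."""

  @property
  def name(self):
    # Used by _export_eval_result as the subdirectory name under
    # <model_dir>/export.
    return 'logging_exporter'

  def export(self, estimator, export_path, checkpoint_path, eval_result,
             is_the_final_export):
    tf.logging.info('Would export checkpoint %s to %s (final=%s).',
                    checkpoint_path, export_path, is_the_final_export)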