author    Jianwei Xie <xiejw@google.com>  2017-10-05 12:58:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-05 13:07:40 -0700
commit    631d3434ff33debfd0bf46d9d8602172f549c82d (patch)
tree      e03196bd1b8e35d5fc4e85bacde43dc3b215f7c0 /tensorflow/python
parent    a429d07bf545b5fd25c44f95fd50e012440bf99b (diff)
Adds throttle_secs into run_master
PiperOrigin-RevId: 171194766
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/estimator/training.py      |  74
-rw-r--r--  tensorflow/python/estimator/training_test.py | 268
2 files changed, 307 insertions(+), 35 deletions(-)
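
For orientation, here is a minimal usage sketch of the code path this commit changes, assuming TF 1.4-era tf.estimator APIs; the model_fn, input_fn, and model_dir below are hypothetical placeholders, not part of this change. With throttle_secs threaded into run_master, the master rate-limits checkpoint-triggered evaluations rather than evaluating after every save.

# Hedged usage sketch (hypothetical model_fn/input_fn; TF 1.4-era APIs).
import tensorflow as tf

def model_fn(features, labels, mode):
  # One-variable "model" so the sketch stays self-contained.
  w = tf.get_variable('w', [], dtype=tf.float32)
  loss = tf.reduce_mean(tf.square(w - features['x']))
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def input_fn():
  return {'x': tf.constant([[1.0]])}, tf.constant([[0.0]])

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='/tmp/m')
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=1000)
# After this commit, throttle_secs also bounds how often the master
# evaluates in response to checkpoint saves; saves that arrive sooner are
# skipped, and a final eval/export still happens after training ends.
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=10,
                                  throttle_secs=600)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)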
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 5c0ebbea35..64b014a6b5 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -519,23 +519,51 @@ class _TrainingExecutor(object):
class NewCheckpointListener(
basic_session_run_hooks.CheckpointSaverListener):
- def __init__(self, estimator, eval_spec, max_training_steps):
- # pylint: disable=protected-access
- self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec,
- max_training_steps)
- # pylint: enable=protected-access
+ def __init__(self, evaluator, eval_throttle_secs):
+ self._evaluator = evaluator
+ self._eval_throttle_secs = eval_throttle_secs
+
+ def begin(self):
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=self._eval_throttle_secs)
def after_save(self, session, global_step_value):
- del session, global_step_value
- self._evaluator.evaluate_and_export()
+ del session # unused; required by signature.
+
+ if self._timer.should_trigger_for_step(global_step_value):
+ self._timer.update_last_triggered_step(global_step_value)
+ self._evaluator.evaluate_and_export()
+ else:
+ logging.info(
+ 'Skipping the current checkpoint eval due to throttle_secs '
+ '({} secs).'.format(self._eval_throttle_secs))
+
+ # Final export signal: for any eval result with global_step >= train
+ # max_steps, the evaluator will send the final export signal. There is a
+ # small chance that the Estimator.train stopping logic sees a different
+ # global_step value (due to a global-step race condition, and because the
+ # saver sees a larger value at checkpoint-saving time), in which case
+ # training does not stop there. When training does end, a new checkpoint is
+ # generated, which triggers the listener again; so the final export could
+ # be triggered twice.
+ #
+ # Here, however, throttle_secs will skip the next intermediate checkpoint,
+ # so the chance of a double final export is very small.
+ evaluator = _TrainingExecutor._Evaluator(
+ self._estimator, self._eval_spec, self._train_spec.max_steps)
# When the underlying `Estimator` object saves a new checkpoint, we would
# like this callback to be called so that evaluation and export can trigger.
saving_listeners = [
- NewCheckpointListener(self._estimator, self._eval_spec,
- self._train_spec.max_steps)
+ NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
]
- return self._start_distributed_training(saving_listeners=saving_listeners)
+ self._start_distributed_training(saving_listeners=saving_listeners)
+
+ if not evaluator.is_final_export_triggered:
+ logging.info('Training has ended, but the last eval was skipped due to '
+ 'eval throttle_secs. Now evaluating the final checkpoint.')
+ evaluator.evaluate_and_export()
def run_evaluator(self):
"""Runs task evaluator."""
@@ -580,6 +608,11 @@ class _TrainingExecutor(object):
max_steps=self._train_spec.max_steps,
hooks=train_hooks)
+ # Final export signal: for any eval result with global_step >= train
+ # max_steps, the evaluator will send the final export signal. The
+ # _should_stop_local_train check will then exit the `while True` loop, as
+ # the stopping condition is satisfied (both checks use the same
+ # global_step value, i.e., there is no race condition).
metrics = evaluator.evaluate_and_export()
if not metrics:
@@ -656,6 +689,11 @@ class _TrainingExecutor(object):
self._train_spec.max_steps)
return
+ # Final export signal: for any eval result with global_step >= train
+ # max_steps, the evaluator will send the final export signal. The next
+ # iteration of the while loop will then end the continuous eval, as the
+ # stopping condition is satisfied (both checks use the same global_step
+ # value, i.e., there is no race condition).
start = time.time()
latest_eval_result = evaluator.evaluate_and_export()
@@ -673,10 +711,15 @@ class _TrainingExecutor(object):
def __init__(self, estimator, eval_spec, max_training_steps):
self._estimator = estimator
self._eval_spec = eval_spec
+ self._is_final_export_triggered = False
self._previous_ckpt_path = None
self._last_warning_time = 0
self._max_training_steps = max_training_steps
+ @property
+ def is_final_export_triggered(self):
+ return self._is_final_export_triggered
+
def evaluate_and_export(self):
"""Evaluate and (maybe) export the current model.
@@ -720,15 +763,16 @@ class _TrainingExecutor(object):
'Internal error: `Estimator.evaluate` result should have '
'`global_step` in result. Given {}'.format(eval_result))
- # TODO(isaprykin): There is a potential race condition here in the
- # distributed setting. The worker job that performs training
- # might stop at a later global step value than the evalutor job.
is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
self._max_training_steps
if self._max_training_steps else False)
self._export_eval_result(eval_result, latest_ckpt_path,
is_the_final_export)
+ if is_the_final_export:
+ logging.debug('Calling exporter with `is_the_final_export=True`.')
+ self._is_final_export_triggered = True
+
self._last_warning_time = 0
self._previous_ckpt_path = latest_ckpt_path
return eval_result
@@ -749,8 +793,8 @@ class _TrainingExecutor(object):
for exporter in self._eval_spec.exporters:
exporter.export(
- self._estimator,
- os.path.join(
+ estimator=self._estimator,
+ export_path=os.path.join(
compat.as_str_any(export_dir_base),
compat.as_str_any(exporter.name)),
checkpoint_path=checkpoint_path,
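
The listener added above distills to the following standalone sketch; ThrottledCheckpointListener and do_evaluation are hypothetical names, and this is a reconstruction of the pattern, not the Estimator internals.

# Standalone sketch of the throttled-listener pattern added above.
# ThrottledCheckpointListener and do_evaluation are hypothetical names.
from tensorflow.python.training import basic_session_run_hooks

class ThrottledCheckpointListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """Runs a callback after checkpoint saves, at most once per throttle_secs."""

  def __init__(self, do_evaluation, throttle_secs):
    self._do_evaluation = do_evaluation
    self._throttle_secs = throttle_secs

  def begin(self):
    # Created in begin() so each training run starts with a fresh timer.
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=self._throttle_secs)

  def after_save(self, session, global_step_value):
    del session  # unused; required by the listener signature.
    if self._timer.should_trigger_for_step(global_step_value):
      self._timer.update_last_triggered_step(global_step_value)
      self._do_evaluation()

# A listener like this is attached via Estimator.train, as run_master does:
#   estimator.train(input_fn, max_steps=...,
#                   saving_listeners=[ThrottledCheckpointListener(fn, 600)])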
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 40972ab5a0..8c00ebddf3 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
@@ -692,37 +693,145 @@ class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
mock_sleep.assert_not_called()
-class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest,
- test.TestCase):
+class TrainingExecutorRunMasterTest(test.TestCase):
"""Tests run_chief of _TrainingExecutor."""
- def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
- test.TestCase.__init__(self, methodName)
- _TrainingExecutorTrainingTest.__init__(
- self,
- run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_MASTER))
+ def setUp(self):
+ self._run_config = _create_run_config_with_cluster_spec(
+ _TF_CONFIG_FOR_MASTER)
@test.mock.patch.object(server_lib, 'Server')
def test_no_delay_for_master(self, _):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
mock_est.config = self._run_config
mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
executor = training._TrainingExecutor(mock_est, mock_train_spec,
mock_eval_spec)
with test.mock.patch.object(time, 'sleep') as mock_sleep:
- self._run_task(executor)
+ executor.run_master()
mock_sleep.assert_not_called()
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
+ def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
+ mock_est.config = self._run_config
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
+ mock_server_instance = mock_server.return_value
+
+ executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
+ executor.run_master()
+
+ mock_server.assert_called_with(
+ mock_est.config.cluster_spec,
+ job_name=mock_est.config.task_type,
+ task_index=mock_est.config.task_id,
+ config=test.mock.ANY,
+ start=False)
+
+ self.assertTrue(mock_server_instance.start.called)
+
+ mock_est.train.assert_called_with(input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=train_spec.hooks,
+ saving_listeners=test.mock.ANY)
+ mock_est.export_savedmodel.assert_not_called()
+
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
+ def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
+ mock_est.config = self._run_config
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+ tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
+ with test.mock.patch.dict('os.environ', tf_config):
+ executor.run_master()
+ mock_server.assert_not_called()
+
+ def test_fail_with_empty_cluster_spec(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
+ mock_est.config.cluster_spec = None
+ mock_est.config.master = 'grpc://...'
+ mock_est.config.task_type = 'worker'
+ mock_est.config.task_id = 2
+
+ with self.assertRaisesRegexp(RuntimeError,
+ _INVALID_CONFIG_FOR_STD_SERVER_MSG):
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, mock_eval_spec).run_master()
+
+ def test_fail_with_empty_master(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
+ mock_est.config.cluster_spec = {'worker': 'dummy'}
+ mock_est.config.master = ''
+ mock_est.config.task_type = 'worker'
+ mock_est.config.task_id = 2
+
+ with self.assertRaisesRegexp(RuntimeError,
+ _INVALID_CONFIG_FOR_STD_SERVER_MSG):
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, mock_eval_spec).run_master()
+
+ def test_fail_with_empty_task_type(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
+ mock_est.config.cluster_spec = {'worker': 'dummy'}
+ mock_est.config.master = 'grpc://...'
+ mock_est.config.task_type = ''
+ mock_est.config.task_id = 2
+
+ with self.assertRaisesRegexp(RuntimeError,
+ _INVALID_CONFIG_FOR_STD_SERVER_MSG):
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, mock_eval_spec).run_master()
+
+ def test_fail_with_none_task_id(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
+ mock_est.config.cluster_spec = {'worker': 'dummy'}
+ mock_est.config.master = 'grpc://...'
+ mock_est.config.task_type = 'worker'
+ mock_est.config.task_id = None
+
+ with self.assertRaisesRegexp(RuntimeError,
+ _INVALID_CONFIG_FOR_STD_SERVER_MSG):
+ training._TrainingExecutor(
+ mock_est, mock_train_spec, mock_eval_spec).run_master()
+
@test.mock.patch.object(server_lib, 'Server')
- def test_run_master_triggers_evaluate(self, _):
+ def test_run_master_triggers_evaluate_and_export(self, _):
def estimator_train(saving_listeners, *args, **kwargs):
# There shalt be a saving_listener. Estimator is going to call
# `after_save`.
del args, kwargs
+ saving_listeners[0].begin()
saving_listeners[0].after_save(session=None, global_step_value=None)
mock_est = test.mock.Mock(
@@ -730,18 +839,14 @@ class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest,
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
mock_est.config = self._run_config
- def export(estimator, *args, **kwargs):
- del args, kwargs
- estimator.export_was_called = True
-
exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
exporter.name = 'see_whether_export_is_called'
- exporter.export = export
train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
eval_spec = training.EvalSpec(
input_fn=lambda: 1, steps=2, exporters=exporter)
- mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ mock_est.evaluate.return_value = eval_result
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_master()
@@ -752,7 +857,109 @@ class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest,
steps=eval_spec.steps,
checkpoint_path='checkpoint_path/',
hooks=eval_spec.hooks)
- self.assertTrue(mock_est.export_was_called)
+ self.assertEqual(1, exporter.export.call_count)
+ exporter.export.assert_called_with(
+ estimator=mock_est,
+ export_path=os.path.join('path/', 'export', exporter.name),
+ checkpoint_path='checkpoint_path/',
+ eval_result=eval_result,
+ is_the_final_export=True)
+
+ @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
+ @test.mock.patch.object(server_lib, 'Server')
+ def test_run_master_throttle_eval(self, _, mock_timer_class):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+
+ mock_timer = test.mock.Mock()
+ mock_timer_class.return_value = mock_timer
+
+ def estimator_train(saving_listeners, *args, **kwargs):
+ del args, kwargs
+ saving_listeners[0].begin()
+
+ # Simulate three checkpoint saves.
+ mock_timer.should_trigger_for_step.return_value = True
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
+ mock_timer.should_trigger_for_step.return_value = False
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
+ mock_timer.should_trigger_for_step.return_value = True
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
+ mock_est.train = estimator_train
+ mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
+ mock_est.config = self._run_config
+
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_whether_export_is_called'
+
+ train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)
+
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
+ {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ ]
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ executor.run_master()
+
+ self.assertEqual(2, mock_est.evaluate.call_count)
+ self.assertEqual(2, exporter.export.call_count)
+
+ is_final_export_list = [call[1]['is_the_final_export']
+ for call in exporter.export.call_args_list]
+ self.assertEqual([False, True], is_final_export_list)
+
+ @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
+ @test.mock.patch.object(server_lib, 'Server')
+ def test_run_master_throttle_eval_which_skips_final_ckpt(
+ self, _, mock_timer_class):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+
+ mock_timer = test.mock.Mock()
+ mock_timer_class.return_value = mock_timer
+
+ def estimator_train(saving_listeners, *args, **kwargs):
+ del args, kwargs
+ saving_listeners[0].begin()
+
+ # Simulate two checkpoint saves.
+ mock_timer.should_trigger_for_step.return_value = True
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
+ # The final ckpt is skipped by the timer. It will be picked up by the
+ # final export check in the code.
+ mock_timer.should_trigger_for_step.return_value = False
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
+ mock_est.train = estimator_train
+ mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
+ mock_est.config = self._run_config
+
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_whether_export_is_called'
+
+ train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)
+
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
+ {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ ]
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ executor.run_master()
+
+ self.assertEqual(2, mock_est.evaluate.call_count)
+ self.assertEqual(2, exporter.export.call_count)
+
+ is_final_export_list = [call[1]['is_the_final_export']
+ for call in exporter.export.call_args_list]
+ self.assertEqual([False, True], is_final_export_list)
class TrainingExecutorRunEvaluatorTest(test.TestCase):
@@ -803,6 +1010,19 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
exporter.name = 'see_how_many_times_export_is_called'
+ mock_est.times_export_was_called = 0
+ mock_est.times_final_export_was_called = 0
+ def export(estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ del export_path, checkpoint_path, eval_result
+ estimator.times_export_was_called += 1
+ # The final export happens at the end.
+ self.assertEqual(0, estimator.times_final_export_was_called)
+ if is_the_final_export:
+ estimator.times_final_export_was_called += 1
+
+ exporter.export = export
+
eval_spec = training.EvalSpec(
input_fn=lambda: 1,
start_delay_secs=0,
@@ -813,7 +1033,8 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
executor.run_evaluator()
self.assertEqual(2, mock_est.evaluate.call_count)
- self.assertEqual(2, exporter.export.call_count)
+ self.assertEqual(2, mock_est.times_export_was_called)
+ self.assertEqual(1, mock_est.times_final_export_was_called)
def test_final_export_is_true_in_the_end(self):
training_max_step = 200
@@ -1135,9 +1356,15 @@ class TrainingExecutorRunLocalTest(test.TestCase):
mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
mock_est.times_export_was_called = 0
- def export(estimator, *args, **kwargs):
- del args, kwargs
+ mock_est.times_final_export_was_called = 0
+ def export(estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ del export_path, checkpoint_path, eval_result
estimator.times_export_was_called += 1
+ # The final export happens at the end.
+ self.assertEqual(0, estimator.times_final_export_was_called)
+ if is_the_final_export:
+ estimator.times_final_export_was_called += 1
exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
exporter.name = 'see_how_many_times_export_is_called'
@@ -1165,6 +1392,7 @@ class TrainingExecutorRunLocalTest(test.TestCase):
self.assertEqual(3, mock_est.train.call_count)
self.assertEqual(3, mock_est.evaluate.call_count)
self.assertEqual(3, mock_est.times_export_was_called)
+ self.assertEqual(1, mock_est.times_final_export_was_called)
def test_handles_no_new_checkpoint_found(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
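
Finally, the timer-mocking pattern shared by the new throttle tests can be sketched standalone; this reuses the hypothetical ThrottledCheckpointListener from the earlier sketch and is an illustration of the pattern, not the test file's code.

# Standalone sketch of the timer-mocking pattern used by the throttle tests
# above. Patching SecondOrStepTimer makes the throttle decision scripted
# rather than wall-clock dependent.
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks

# Hypothetical module holding the earlier ThrottledCheckpointListener sketch.
from throttled_listener_sketch import ThrottledCheckpointListener

class ThrottledListenerTest(test.TestCase):

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  def test_throttled_save_skips_evaluation(self, mock_timer_class):
    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer
    do_evaluation = test.mock.Mock()

    listener = ThrottledCheckpointListener(do_evaluation, throttle_secs=10)
    listener.begin()  # picks up the mocked timer class

    # First save: the timer fires, so the evaluation callback runs.
    mock_timer.should_trigger_for_step.return_value = True
    listener.after_save(session=None, global_step_value=100)

    # Second save arrives too soon: the timer holds it back.
    mock_timer.should_trigger_for_step.return_value = False
    listener.after_save(session=None, global_step_value=200)

    self.assertEqual(1, do_evaluation.call_count)

if __name__ == '__main__':
  test.main()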