diff options
author | Jianwei Xie <xiejw@google.com> | 2017-10-05 12:58:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 13:07:40 -0700 |
commit | 631d3434ff33debfd0bf46d9d8602172f549c82d (patch) | |
tree | e03196bd1b8e35d5fc4e85bacde43dc3b215f7c0 /tensorflow/python | |
parent | a429d07bf545b5fd25c44f95fd50e012440bf99b (diff) |
Adds throlle_secs into run_master
PiperOrigin-RevId: 171194766
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/estimator/training.py | 74 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 268 |
2 files changed, 307 insertions, 35 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 5c0ebbea35..64b014a6b5 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -519,23 +519,51 @@ class _TrainingExecutor(object): class NewCheckpointListener( basic_session_run_hooks.CheckpointSaverListener): - def __init__(self, estimator, eval_spec, max_training_steps): - # pylint: disable=protected-access - self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec, - max_training_steps) - # pylint: enable=protected-access + def __init__(self, evaluator, eval_throttle_secs): + self._evaluator = evaluator + self._eval_throttle_secs = eval_throttle_secs + + def begin(self): + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=self._eval_throttle_secs) def after_save(self, session, global_step_value): - del session, global_step_value - self._evaluator.evaluate_and_export() + del session # unused; required by signature. + + if self._timer.should_trigger_for_step(global_step_value): + self._timer.update_last_triggered_step(global_step_value) + self._evaluator.evaluate_and_export() + else: + logging.info( + 'Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) + + # Final export signal: For any eval result with global_step >= train + # max_steps, the evaluator will send the final export signal. There is a + # small chance that the Estimator.train stopping logic sees a different + # global_step value (due to global step race condition and the fact the + # saver sees a larger value for checkpoing saving), which does not end + # the training. When the training ends, a new checkpoint is generated, which + # triggers the listener again. So, it could be the case the final export is + # triggered twice. + # + # But here, throttle_secs will skip the next intermediate checkpoint and, + # so, the double final export chance is very small. + evaluator = _TrainingExecutor._Evaluator( + self._estimator, self._eval_spec, self._train_spec.max_steps) # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. saving_listeners = [ - NewCheckpointListener(self._estimator, self._eval_spec, - self._train_spec.max_steps) + NewCheckpointListener(evaluator, self._eval_spec.throttle_secs) ] - return self._start_distributed_training(saving_listeners=saving_listeners) + self._start_distributed_training(saving_listeners=saving_listeners) + + if not evaluator.is_final_export_triggered: + logging.info('Training has already ended. But the last eval is skipped ' + 'due to eval throttle_secs. Now evaluating the final ' + 'checkpoint.') + evaluator.evaluate_and_export() def run_evaluator(self): """Runs task evaluator.""" @@ -580,6 +608,11 @@ class _TrainingExecutor(object): max_steps=self._train_spec.max_steps, hooks=train_hooks) + # Final export signal: For any eval result with global_step >= train + # max_steps, the evaluator will send the final export signal. The + # _should_stop_local_train will then end the while True as the stopping + # condition is satisfied (both checks use the same global_step value, + # i.e., no race condition) metrics = evaluator.evaluate_and_export() if not metrics: @@ -656,6 +689,11 @@ class _TrainingExecutor(object): self._train_spec.max_steps) return + # Final export signal: For any eval result with global_step >= train + # max_steps, the evaluator will send the final export signal. The next + # iteration of while loop will end the continuous eval as the stopping + # condition is satisfied (both checks use the same global_step value, + # i.e., no race condition) start = time.time() latest_eval_result = evaluator.evaluate_and_export() @@ -673,10 +711,15 @@ class _TrainingExecutor(object): def __init__(self, estimator, eval_spec, max_training_steps): self._estimator = estimator self._eval_spec = eval_spec + self._is_final_export_triggered = False self._previous_ckpt_path = None self._last_warning_time = 0 self._max_training_steps = max_training_steps + @property + def is_final_export_triggered(self): + return self._is_final_export_triggered + def evaluate_and_export(self): """Evaluate and (maybe) export the current model. @@ -720,15 +763,16 @@ class _TrainingExecutor(object): 'Internal error: `Estimator.evaluate` result should have ' '`global_step` in result. Given {}'.format(eval_result)) - # TODO(isaprykin): There is a potential race condition here in the - # distributed setting. The worker job that performs training - # might stop at a later global step value than the evalutor job. is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >= self._max_training_steps if self._max_training_steps else False) self._export_eval_result(eval_result, latest_ckpt_path, is_the_final_export) + if is_the_final_export: + logging.debug('Calling exporter with the `is_the_final_export=True`.') + self._is_final_export_triggered = True + self._last_warning_time = 0 self._previous_ckpt_path = latest_ckpt_path return eval_result @@ -749,8 +793,8 @@ class _TrainingExecutor(object): for exporter in self._eval_spec.exporters: exporter.export( - self._estimator, - os.path.join( + estimator=self._estimator, + export_path=os.path.join( compat.as_str_any(export_dir_base), compat.as_str_any(exporter.name)), checkpoint_path=checkpoint_path, diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 40972ab5a0..8c00ebddf3 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -45,6 +45,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook @@ -692,37 +693,145 @@ class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest, mock_sleep.assert_not_called() -class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest, - test.TestCase): +class TrainingExecutorRunMasterTest(test.TestCase): """Tests run_chief of _TrainingExecutor.""" - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - _TrainingExecutorTrainingTest.__init__( - self, - run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_MASTER)) + def setUp(self): + self._run_config = _create_run_config_with_cluster_spec( + _TF_CONFIG_FOR_MASTER) @test.mock.patch.object(server_lib, 'Server') def test_no_delay_for_master(self, _): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} mock_est.config = self._run_config mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) executor = training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec) with test.mock.patch.object(time, 'sleep') as mock_sleep: - self._run_task(executor) + executor.run_master() mock_sleep.assert_not_called() + @test.mock.patch.object(time, 'sleep') + @test.mock.patch.object(server_lib, 'Server') + def test_train_with_train_spec(self, mock_server, unused_mock_sleep): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} + mock_est.config = self._run_config + train_spec = training.TrainSpec( + input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) + mock_server_instance = mock_server.return_value + + executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec) + executor.run_master() + + mock_server.assert_called_with( + mock_est.config.cluster_spec, + job_name=mock_est.config.task_type, + task_index=mock_est.config.task_id, + config=test.mock.ANY, + start=False) + + self.assertTrue(mock_server_instance.start.called) + + mock_est.train.assert_called_with(input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=train_spec.hooks, + saving_listeners=test.mock.ANY) + mock_est.export_savedmodel.assert_not_called() + + @test.mock.patch.object(time, 'sleep') + @test.mock.patch.object(server_lib, 'Server') + def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123} + mock_est.config = self._run_config + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[]) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)} + with test.mock.patch.dict('os.environ', tf_config): + executor.run_master() + mock_server.assert_not_called() + + def test_fail_with_empty_cluster_spec(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) + mock_est.config.cluster_spec = None + mock_est.config.master = 'grpc://...' + mock_est.config.task_type = 'worker' + mock_est.config.task_id = 2 + + with self.assertRaisesRegexp(RuntimeError, + _INVALID_CONFIG_FOR_STD_SERVER_MSG): + training._TrainingExecutor( + mock_est, mock_train_spec, mock_eval_spec).run_master() + + def test_fail_with_empty_master(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) + mock_est.config.cluster_spec = {'worker': 'dummy'} + mock_est.config.master = '' + mock_est.config.task_type = 'worker' + mock_est.config.task_id = 2 + + with self.assertRaisesRegexp(RuntimeError, + _INVALID_CONFIG_FOR_STD_SERVER_MSG): + training._TrainingExecutor( + mock_est, mock_train_spec, mock_eval_spec).run_master() + + def test_fail_with_empty_task_type(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) + mock_est.config.cluster_spec = {'worker': 'dummy'} + mock_est.config.master = 'grpc://...' + mock_est.config.task_type = '' + mock_est.config.task_id = 2 + + with self.assertRaisesRegexp(RuntimeError, + _INVALID_CONFIG_FOR_STD_SERVER_MSG): + training._TrainingExecutor( + mock_est, mock_train_spec, mock_eval_spec).run_master() + + def test_fail_with_none_task_id(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) + mock_est.config.cluster_spec = {'worker': 'dummy'} + mock_est.config.master = 'grpc://...' + mock_est.config.task_type = 'worker' + mock_est.config.task_id = None + + with self.assertRaisesRegexp(RuntimeError, + _INVALID_CONFIG_FOR_STD_SERVER_MSG): + training._TrainingExecutor( + mock_est, mock_train_spec, mock_eval_spec).run_master() + @test.mock.patch.object(server_lib, 'Server') - def test_run_master_triggers_evaluate(self, _): + def test_run_master_triggers_evaluate_and_export(self, _): def estimator_train(saving_listeners, *args, **kwargs): # There shalt be a saving_listener. Estimator is going to call # `after_save`. del args, kwargs + saving_listeners[0].begin() saving_listeners[0].after_save(session=None, global_step_value=None) mock_est = test.mock.Mock( @@ -730,18 +839,14 @@ class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest, mock_est.latest_checkpoint.return_value = 'checkpoint_path/' mock_est.config = self._run_config - def export(estimator, *args, **kwargs): - del args, kwargs - estimator.export_was_called = True - exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) exporter.name = 'see_whether_export_is_called' - exporter.export = export train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) eval_spec = training.EvalSpec( input_fn=lambda: 1, steps=2, exporters=exporter) - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps} + mock_est.evaluate.return_value = eval_result executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_master() @@ -752,7 +857,109 @@ class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest, steps=eval_spec.steps, checkpoint_path='checkpoint_path/', hooks=eval_spec.hooks) - self.assertTrue(mock_est.export_was_called) + self.assertEqual(1, exporter.export.call_count) + exporter.export.assert_called_with( + estimator=mock_est, + export_path=os.path.join('path/', 'export', exporter.name), + checkpoint_path='checkpoint_path/', + eval_result=eval_result, + is_the_final_export=True) + + @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer') + @test.mock.patch.object(server_lib, 'Server') + def test_run_master_throttle_eval(self, _, mock_timer_class): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + + mock_timer = test.mock.Mock() + mock_timer_class.return_value = mock_timer + + def estimator_train(saving_listeners, *args, **kwargs): + del args, kwargs + saving_listeners[0].begin() + + # Call three times. + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + + mock_timer.should_trigger_for_step.return_value = False + saving_listeners[0].after_save(session=None, global_step_value=None) + + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + + mock_est.train = estimator_train + mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2'] + mock_est.config = self._run_config + + exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) + exporter.name = 'see_whether_export_is_called' + + train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) + eval_spec = training.EvalSpec( + input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10) + + mock_est.evaluate.side_effect = [ + {_GLOBAL_STEP_KEY: train_spec.max_steps //2}, + {_GLOBAL_STEP_KEY: train_spec.max_steps} + ] + + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) + executor.run_master() + + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, exporter.export.call_count) + + is_final_export_list = [call[1]['is_the_final_export'] + for call in exporter.export.call_args_list] + self.assertEqual([False, True], is_final_export_list) + + @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer') + @test.mock.patch.object(server_lib, 'Server') + def test_run_master_throttle_eval_which_skips_final_ckpt( + self, _, mock_timer_class): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + + mock_timer = test.mock.Mock() + mock_timer_class.return_value = mock_timer + + def estimator_train(saving_listeners, *args, **kwargs): + del args, kwargs + saving_listeners[0].begin() + + # Call two times. + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + + # The final ckpt is skipped by the timer. It will be picked up the final + # export check in the code. + mock_timer.should_trigger_for_step.return_value = False + saving_listeners[0].after_save(session=None, global_step_value=None) + + mock_est.train = estimator_train + mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2'] + mock_est.config = self._run_config + + exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) + exporter.name = 'see_whether_export_is_called' + + train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) + eval_spec = training.EvalSpec( + input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10) + + mock_est.evaluate.side_effect = [ + {_GLOBAL_STEP_KEY: train_spec.max_steps //2}, + {_GLOBAL_STEP_KEY: train_spec.max_steps} + ] + + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) + executor.run_master() + + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, exporter.export.call_count) + + is_final_export_list = [call[1]['is_the_final_export'] + for call in exporter.export.call_args_list] + self.assertEqual([False, True], is_final_export_list) class TrainingExecutorRunEvaluatorTest(test.TestCase): @@ -803,6 +1010,19 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) exporter.name = 'see_how_many_times_export_is_called' + mock_est.times_export_was_called = 0 + mock_est.times_final_export_was_called = 0 + def export(estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + del export_path, checkpoint_path, eval_result + estimator.times_export_was_called += 1 + # final_export is happend at the end. + self.assertEqual(0, estimator.times_final_export_was_called) + if is_the_final_export: + estimator.times_final_export_was_called += 1 + + exporter.export = export + eval_spec = training.EvalSpec( input_fn=lambda: 1, start_delay_secs=0, @@ -813,7 +1033,8 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): executor.run_evaluator() self.assertEqual(2, mock_est.evaluate.call_count) - self.assertEqual(2, exporter.export.call_count) + self.assertEqual(2, mock_est.times_export_was_called) + self.assertEqual(1, mock_est.times_final_export_was_called) def test_final_export_is_true_in_the_end(self): training_max_step = 200 @@ -1135,9 +1356,15 @@ class TrainingExecutorRunLocalTest(test.TestCase): mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn mock_est.times_export_was_called = 0 - def export(estimator, *args, **kwargs): - del args, kwargs + mock_est.times_final_export_was_called = 0 + def export(estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + del export_path, checkpoint_path, eval_result estimator.times_export_was_called += 1 + # final_export is happend at the end. + self.assertEqual(0, estimator.times_final_export_was_called) + if is_the_final_export: + estimator.times_final_export_was_called += 1 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) exporter.name = 'see_how_many_times_export_is_called' @@ -1165,6 +1392,7 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(3, mock_est.train.call_count) self.assertEqual(3, mock_est.evaluate.call_count) self.assertEqual(3, mock_est.times_export_was_called) + self.assertEqual(1, mock_est.times_final_export_was_called) def test_handles_no_new_checkpoint_found(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') |