diff options
-rw-r--r-- | tensorflow/python/estimator/training.py | 7 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 33 |
2 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 5c04387b65..e6bd263c80 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -837,6 +837,13 @@ class _TrainingExecutor(object): if difference > 0: logging.info('Waiting %f secs before starting next eval run.', difference) time.sleep(difference) + elif (throttle_secs == 0 and + eval_result.status != _EvalStatus.EVALUATED): + # Prints a user-actionable warning to avoid unnecessary load on evaluator. + logging.warning( + 'EvalSpec.throttle_secs is set as 0. This might overload the job ' + 'before finding (next) new checkpoint. Please consider to increase ' + 'it.') return (eval_result, should_early_stop) diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index dc106c7d3b..7d46917a6f 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -83,6 +83,9 @@ _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`' _INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG' _INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`' _INVALID_TASK_TYPE = '`estimator.config` must have task_type set.' +_INPROPER_THROTTL_SECS = ( + 'EvalSpec.throttle_secs is set as 0.*Please consider to increase') + # The message should NOT have 'local' word as part of it. As (?!word) is looking # ahead, so, the $ (ending) check is required; otherwise, it will match # partially and return successuful. @@ -1281,7 +1284,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): ] eval_spec = training.EvalSpec( - input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) + input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2) executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) with test.mock.patch.object(logging, 'warning') as mock_log: @@ -1295,6 +1298,34 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): # successuful evaluation) self.assertEqual(2, mock_log.call_count) + def test_warning_if_throttle_secs_is_zero(self): + training_max_step = 200 + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.evaluate.side_effect = [ + {_GLOBAL_STEP_KEY: training_max_step} + ] + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec.max_steps = training_max_step + + self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) + + # We need to make the first one invalid, so it will check the + # throttle_secs=0. + mock_est.latest_checkpoint.side_effect = [None, 'path'] + + eval_spec = training.EvalSpec( + input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) + with test.mock.patch.object(logging, 'warning') as mock_log: + executor.run_evaluator() + + # First ckpt is invalid. + self.assertEqual(2, mock_est.latest_checkpoint.call_count) + self.assertEqual(1, mock_est.evaluate.call_count) + + self.assertRegexpMatches(str(mock_log.call_args), _INPROPER_THROTTL_SECS) + def test_continuous_eval_listener_eval_result(self): training_max_step = 200 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) |