aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-08-20 10:16:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 10:22:07 -0700
commite39ee345009c74b47ec33de64af6955e4677768b (patch)
tree3255a4b55f7b08399bc35874564ab2e8b3ea861b
parentc09a0176795a321b13caaa3126541290c235c95f (diff)
Improves the error essage for throttle_secs=0 case.
User should be aware that the job load might be increased a lot due to this settings. PiperOrigin-RevId: 209439297
-rw-r--r--tensorflow/python/estimator/training.py7
-rw-r--r--tensorflow/python/estimator/training_test.py33
2 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 5c04387b65..e6bd263c80 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -837,6 +837,13 @@ class _TrainingExecutor(object):
if difference > 0:
logging.info('Waiting %f secs before starting next eval run.', difference)
time.sleep(difference)
+ elif (throttle_secs == 0 and
+ eval_result.status != _EvalStatus.EVALUATED):
+ # Prints a user-actionable warning to avoid unnecessary load on evaluator.
+ logging.warning(
+ 'EvalSpec.throttle_secs is set as 0. This might overload the job '
+ 'before finding (next) new checkpoint. Please consider to increase '
+ 'it.')
return (eval_result, should_early_stop)
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index dc106c7d3b..7d46917a6f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -83,6 +83,9 @@ _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
+_INPROPER_THROTTL_SECS = (
+ 'EvalSpec.throttle_secs is set as 0.*Please consider to increase')
+
# The message should NOT have 'local' word as part of it. As (?!word) is looking
# ahead, so, the $ (ending) check is required; otherwise, it will match
# partially and return successuful.
@@ -1281,7 +1284,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
]
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(logging, 'warning') as mock_log:
@@ -1295,6 +1298,34 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
# successuful evaluation)
self.assertEqual(2, mock_log.call_count)
+ def test_warning_if_throttle_secs_is_zero(self):
+ training_max_step = 200
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: training_max_step}
+ ]
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec.max_steps = training_max_step
+
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ # We need to make the first one invalid, so it will check the
+ # throttle_secs=0.
+ mock_est.latest_checkpoint.side_effect = [None, 'path']
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ executor.run_evaluator()
+
+ # First ckpt is invalid.
+ self.assertEqual(2, mock_est.latest_checkpoint.call_count)
+ self.assertEqual(1, mock_est.evaluate.call_count)
+
+ self.assertRegexpMatches(str(mock_log.call_args), _INPROPER_THROTTL_SECS)
+
def test_continuous_eval_listener_eval_result(self):
training_max_step = 200
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)