diff options
author | Gunhan Gulsoy <gunan@google.com> | 2017-10-03 14:59:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-03 15:07:25 -0700 |
commit | 435b31b9fcbb9aeeebf80ee7ca0a154a0e99b826 (patch) | |
tree | d17af7fc3581cd34d681502568bf0dee1f0e0bfe /tensorflow/python | |
parent | 66df43d09c99207a06f4f697b9baa6a77857e565 (diff) |
Automated g4 rollback of changelist 170892257
PiperOrigin-RevId: 170919783
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/estimator/training.py | 4 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 22 |
2 files changed, 0 insertions, 26 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index f3d1aca717..f4ccea6806 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -392,10 +392,6 @@ class _TrainingExecutor(object): metrics = evaluator.evaluate_and_export() - if not metrics: - # This is unexpected. Training should always end with a new checkpoint. - raise RuntimeError('There was no new checkpoint after the training.') - if _should_stop_local_train(metrics[ops.GraphKeys.GLOBAL_STEP]): break diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 39c8bffb04..f5b4f88479 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -50,7 +50,6 @@ _INVALID_NAME_MSG = '`name` must be string' _INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0' _INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0' _INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`' -_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.' _INVALID_EXPORT_STRATEGY_MSG = '`export_strategies` must be an ExportStrategy' _DUPLICATE_STRATEGY_NAMES_MSG = '`export_strategies` must have unique names.' _INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`' @@ -1025,27 +1024,6 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(3, mock_est.evaluate.call_count) self.assertEqual(3, mock_est.times_export_fn_was_called) - def test_handles_no_new_checkpoint_found(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = ( - 'no_new_checkpoints_after_the_first_train_step') - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) - eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - # It was going to be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): - executor.run_local() - def test_train_and_evaluate_args(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') mock_est.latest_checkpoint.return_value = 'checkpoint_path/' |