diff options
author | 2018-01-02 15:05:38 -0800 | |
---|---|---|
committer | 2018-01-02 15:09:39 -0800 | |
commit | c25694086e185e844e684ec196aa13667a7c2406 (patch) | |
tree | 717097601a04e5b8c337c948fbe54a6fb4e26d22 /tensorflow/python/estimator/training_test.py | |
parent | a966bbca81509201e7d0e1e30fb4abfdd9dcad4d (diff) |
Adds _TrainingExecutor.run method to automatically invoke correct procedure.
PiperOrigin-RevId: 180598558
Diffstat (limited to 'tensorflow/python/estimator/training_test.py')
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 284 |
1 files changed, 148 insertions, 136 deletions
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 9536ee44d5..2d3f5d6cef 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -312,61 +312,21 @@ class EvalSpecTest(test.TestCase): training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None)) -class TrainAndEvaluteTest(test.TestCase): +class TrainAndEvaluateTest(test.TestCase): - def _mock_executor_instance(self): - mock_instance = test.mock.Mock() - mock_instance.call_task = {} - - def task_fn(name): - def _fn(): - mock_instance.call_task[name] = 1 - return _fn - - mock_instance.run_chief = task_fn('chief') - mock_instance.run_master = task_fn('master') - mock_instance.run_ps = task_fn('ps') - mock_instance.run_evaluator = task_fn('evaluator') - mock_instance.run_worker = task_fn('worker') - mock_instance.run_local = task_fn('local') - - return mock_instance - - def _test_run_task_in_distributed_training(self, run_config): + def test_run_task(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_est.config = run_config mock_train_spec = test.mock.Mock(spec=training.TrainSpec) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor: - mock_executor_instance = self._mock_executor_instance() + mock_executor_instance = test.mock.Mock() mock_executor.return_value = mock_executor_instance training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) mock_executor.assert_called_with(estimator=mock_est, train_spec=mock_train_spec, eval_spec=mock_eval_spec) - return mock_executor_instance - - def test_run_chief(self): - mock_executor = self._test_run_task_in_distributed_training( - run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF)) - self.assertEqual(1, mock_executor.call_task['chief']) - - def test_run_worker(self): - mock_executor = self._test_run_task_in_distributed_training( - run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER)) - self.assertEqual(1, mock_executor.call_task['worker']) - - def test_run_ps(self): - mock_executor = self._test_run_task_in_distributed_training( - run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)) - self.assertEqual(1, mock_executor.call_task['ps']) - - def test_run_evaluator(self): - mock_executor = self._test_run_task_in_distributed_training( - run_config=_create_run_config_with_cluster_spec( - _TF_CONFIG_FOR_EVALUATOR)) - self.assertEqual(1, mock_executor.call_task['evaluator']) + mock_executor_instance.run.assert_called() def test_error_out_if_evaluator_task_id_is_non_zero(self): tf_config = { @@ -378,93 +338,15 @@ class TrainAndEvaluteTest(test.TestCase): 'index': 1 } } - with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR): - self._test_run_task_in_distributed_training( - run_config=_create_run_config_with_cluster_spec(tf_config)) - def test_run_local(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_est.config = run_config_lib.RunConfig() - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) - - with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor: - mock_executor_instance = self._mock_executor_instance() - mock_executor.return_value = mock_executor_instance - training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) - self.assertEqual(1, mock_executor_instance.call_task['local']) - - mock_executor.assert_called_with(estimator=mock_est, - train_spec=mock_train_spec, - eval_spec=mock_eval_spec) - - def test_invalid_local_task(self): - tf_config = { - 'cluster': { - run_config_lib.TaskType.CHIEF: ['host0:0'], - 'local': ['hos1:1'], - }, - 'task': { - 'type': 'local', - 'index': 0 - } - } mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.config = _create_run_config_with_cluster_spec(tf_config) mock_train_spec = test.mock.Mock(spec=training.TrainSpec) mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) - with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER): + with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR): training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) - def test_unsupported_task_due_to_missing_run_task(self): - unsupported_task = 'alloc' - tf_config = { - 'cluster': { - run_config_lib.TaskType.CHIEF: ['host0:0'], - unsupported_task: ['hos1:1'], - }, - 'task': { - 'type': unsupported_task, - 'index': 0 - } - } - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_est.config = _create_run_config_with_cluster_spec(tf_config) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) - - with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor: - # mock_instance has no run_alloc method. - mock_instance = self._mock_executor_instance() - mock_executor.return_value = mock_instance - with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): - training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) - - def test_unsupported_task_due_to_not_callable(self): - unsupported_task = 'alloc' - tf_config = { - 'cluster': { - run_config_lib.TaskType.CHIEF: ['host0:0'], - unsupported_task: ['hos1:1'], - }, - 'task': { - 'type': unsupported_task, - 'index': 0 - } - } - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_est.config = _create_run_config_with_cluster_spec(tf_config) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) - - with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor: - mock_instance = self._mock_executor_instance() - mock_instance.run_alloc = 123 # not callable - mock_executor.return_value = mock_instance - with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): - training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) - def test_invalid_estimator(self): invalid_estimator = object() mock_train_spec = test.mock.Mock(spec=training.TrainSpec) @@ -474,19 +356,6 @@ class TrainAndEvaluteTest(test.TestCase): training.train_and_evaluate(invalid_estimator, mock_train_spec, mock_eval_spec) - def test_invalid_task_type(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_est.config = test.mock.Mock() - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) - - mock_est.config = test.mock.Mock() - mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']}) - mock_est.config.task_type = '' - - with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE): - training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec) - class TrainingExecutorConstructorTest(test.TestCase): """Tests constructor of _TrainingExecutor.""" @@ -554,6 +423,8 @@ class _TrainingExecutorTrainingTest(object): self._run_config = run_config def _run_task(self, executor): + # We should not call executor.run as the test here is intended to test + # run_foo explicitly (foo is the task type). return getattr(executor, 'run_' + self._run_config.task_type)() @test.mock.patch.object(time, 'sleep') @@ -1856,6 +1727,147 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() +class TrainAndEvaluateRunTest(test.TestCase): + + def _test_run_task_and_executor(self, run_config): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = run_config + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + + executor.call_task = {} + + def task_fn(name): + + def _fn(): + executor.call_task[name] = 1 + + return _fn + + executor.run_chief = task_fn('chief') + executor.run_master = task_fn('master') + executor.run_ps = task_fn('ps') + executor.run_evaluator = task_fn('evaluator') + executor.run_worker = task_fn('worker') + executor.run_local = task_fn('local') + return executor + + def test_run_chief(self): + executor = self._test_run_task_and_executor( + run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF)) + executor.run() + self.assertEqual(1, executor.call_task['chief']) + + def test_run_worker(self): + executor = self._test_run_task_and_executor( + run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER)) + executor.run() + self.assertEqual(1, executor.call_task['worker']) + + def test_run_ps(self): + executor = self._test_run_task_and_executor( + run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)) + executor.run() + self.assertEqual(1, executor.call_task['ps']) + + def test_run_evaluator(self): + executor = self._test_run_task_and_executor( + run_config=_create_run_config_with_cluster_spec( + _TF_CONFIG_FOR_EVALUATOR)) + executor.run() + self.assertEqual(1, executor.call_task['evaluator']) + + def test_run_local(self): + executor = self._test_run_task_and_executor( + run_config=run_config_lib.RunConfig()) + executor.run() + self.assertEqual(1, executor.call_task['local']) + + def test_invalid_local_task(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + 'local': ['hos1:1'], + }, + 'task': { + 'type': 'local', # invalid task type. + 'index': 0 + } + } + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = _create_run_config_with_cluster_spec(tf_config) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER): + executor.run() + + def test_unsupported_task_due_to_missing_run_task(self): + unsupported_task = 'alloc' + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + unsupported_task: ['hos1:1'], + }, + 'task': { + 'type': unsupported_task, + 'index': 0 + } + } + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = _create_run_config_with_cluster_spec(tf_config) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): + executor.run() + + def test_unsupported_task_due_to_not_callable(self): + unsupported_task = 'alloc' + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + unsupported_task: ['hos1:1'], + }, + 'task': { + 'type': unsupported_task, + 'index': 0 + } + } + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = _create_run_config_with_cluster_spec(tf_config) + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + executor.run_alloc = 123 # not callable + with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): + executor.run() + + def test_invalid_task_type(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.config = test.mock.Mock() + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) + + mock_est.config = test.mock.Mock() + mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']}) + mock_est.config.task_type = '' + + executor = training._TrainingExecutor(mock_est, mock_train_spec, + mock_eval_spec) + with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE): + executor.run() + + class TrainAndEvaluateIntegrationTest(test.TestCase): def setUp(self): |