aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/training_test.py
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-01-02 15:05:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-02 15:09:39 -0800
commitc25694086e185e844e684ec196aa13667a7c2406 (patch)
tree717097601a04e5b8c337c948fbe54a6fb4e26d22 /tensorflow/python/estimator/training_test.py
parenta966bbca81509201e7d0e1e30fb4abfdd9dcad4d (diff)
Adds _TrainingExecutor.run method to automatically invoke correct procedure.
PiperOrigin-RevId: 180598558
Diffstat (limited to 'tensorflow/python/estimator/training_test.py')
-rw-r--r--tensorflow/python/estimator/training_test.py284
1 files changed, 148 insertions, 136 deletions
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 9536ee44d5..2d3f5d6cef 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -312,61 +312,21 @@ class EvalSpecTest(test.TestCase):
training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None))
-class TrainAndEvaluteTest(test.TestCase):
+class TrainAndEvaluateTest(test.TestCase):
- def _mock_executor_instance(self):
- mock_instance = test.mock.Mock()
- mock_instance.call_task = {}
-
- def task_fn(name):
- def _fn():
- mock_instance.call_task[name] = 1
- return _fn
-
- mock_instance.run_chief = task_fn('chief')
- mock_instance.run_master = task_fn('master')
- mock_instance.run_ps = task_fn('ps')
- mock_instance.run_evaluator = task_fn('evaluator')
- mock_instance.run_worker = task_fn('worker')
- mock_instance.run_local = task_fn('local')
-
- return mock_instance
-
- def _test_run_task_in_distributed_training(self, run_config):
+ def test_run_task(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_est.config = run_config
mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
- mock_executor_instance = self._mock_executor_instance()
+ mock_executor_instance = test.mock.Mock()
mock_executor.return_value = mock_executor_instance
training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
mock_executor.assert_called_with(estimator=mock_est,
train_spec=mock_train_spec,
eval_spec=mock_eval_spec)
- return mock_executor_instance
-
- def test_run_chief(self):
- mock_executor = self._test_run_task_in_distributed_training(
- run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
- self.assertEqual(1, mock_executor.call_task['chief'])
-
- def test_run_worker(self):
- mock_executor = self._test_run_task_in_distributed_training(
- run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
- self.assertEqual(1, mock_executor.call_task['worker'])
-
- def test_run_ps(self):
- mock_executor = self._test_run_task_in_distributed_training(
- run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))
- self.assertEqual(1, mock_executor.call_task['ps'])
-
- def test_run_evaluator(self):
- mock_executor = self._test_run_task_in_distributed_training(
- run_config=_create_run_config_with_cluster_spec(
- _TF_CONFIG_FOR_EVALUATOR))
- self.assertEqual(1, mock_executor.call_task['evaluator'])
+ mock_executor_instance.run.assert_called()
def test_error_out_if_evaluator_task_id_is_non_zero(self):
tf_config = {
@@ -378,93 +338,15 @@ class TrainAndEvaluteTest(test.TestCase):
'index': 1
}
}
- with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
- self._test_run_task_in_distributed_training(
- run_config=_create_run_config_with_cluster_spec(tf_config))
- def test_run_local(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_est.config = run_config_lib.RunConfig()
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
-
- with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
- mock_executor_instance = self._mock_executor_instance()
- mock_executor.return_value = mock_executor_instance
- training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
- self.assertEqual(1, mock_executor_instance.call_task['local'])
-
- mock_executor.assert_called_with(estimator=mock_est,
- train_spec=mock_train_spec,
- eval_spec=mock_eval_spec)
-
- def test_invalid_local_task(self):
- tf_config = {
- 'cluster': {
- run_config_lib.TaskType.CHIEF: ['host0:0'],
- 'local': ['hos1:1'],
- },
- 'task': {
- 'type': 'local',
- 'index': 0
- }
- }
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = _create_run_config_with_cluster_spec(tf_config)
mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
- with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):
+ with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
- def test_unsupported_task_due_to_missing_run_task(self):
- unsupported_task = 'alloc'
- tf_config = {
- 'cluster': {
- run_config_lib.TaskType.CHIEF: ['host0:0'],
- unsupported_task: ['hos1:1'],
- },
- 'task': {
- 'type': unsupported_task,
- 'index': 0
- }
- }
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_est.config = _create_run_config_with_cluster_spec(tf_config)
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
-
- with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
- # mock_instance has no run_alloc method.
- mock_instance = self._mock_executor_instance()
- mock_executor.return_value = mock_instance
- with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
- training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
-
- def test_unsupported_task_due_to_not_callable(self):
- unsupported_task = 'alloc'
- tf_config = {
- 'cluster': {
- run_config_lib.TaskType.CHIEF: ['host0:0'],
- unsupported_task: ['hos1:1'],
- },
- 'task': {
- 'type': unsupported_task,
- 'index': 0
- }
- }
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_est.config = _create_run_config_with_cluster_spec(tf_config)
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
-
- with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
- mock_instance = self._mock_executor_instance()
- mock_instance.run_alloc = 123 # not callable
- mock_executor.return_value = mock_instance
- with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
- training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
-
def test_invalid_estimator(self):
invalid_estimator = object()
mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
@@ -474,19 +356,6 @@ class TrainAndEvaluteTest(test.TestCase):
training.train_and_evaluate(invalid_estimator, mock_train_spec,
mock_eval_spec)
- def test_invalid_task_type(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_est.config = test.mock.Mock()
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
-
- mock_est.config = test.mock.Mock()
- mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']})
- mock_est.config.task_type = ''
-
- with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
- training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
-
class TrainingExecutorConstructorTest(test.TestCase):
"""Tests constructor of _TrainingExecutor."""
@@ -554,6 +423,8 @@ class _TrainingExecutorTrainingTest(object):
self._run_config = run_config
def _run_task(self, executor):
+ # We should not call executor.run as the test here is intended to test
+ # run_foo explicitly (foo is the task type).
return getattr(executor, 'run_' + self._run_config.task_type)()
@test.mock.patch.object(time, 'sleep')
@@ -1856,6 +1727,147 @@ class TrainingExecutorRunLocalTest(test.TestCase):
executor.run_local()
+class TrainAndEvaluateRunTest(test.TestCase):
+
+ def _test_run_task_and_executor(self, run_config):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = run_config
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+
+ executor.call_task = {}
+
+ def task_fn(name):
+
+ def _fn():
+ executor.call_task[name] = 1
+
+ return _fn
+
+ executor.run_chief = task_fn('chief')
+ executor.run_master = task_fn('master')
+ executor.run_ps = task_fn('ps')
+ executor.run_evaluator = task_fn('evaluator')
+ executor.run_worker = task_fn('worker')
+ executor.run_local = task_fn('local')
+ return executor
+
+ def test_run_chief(self):
+ executor = self._test_run_task_and_executor(
+ run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
+ executor.run()
+ self.assertEqual(1, executor.call_task['chief'])
+
+ def test_run_worker(self):
+ executor = self._test_run_task_and_executor(
+ run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
+ executor.run()
+ self.assertEqual(1, executor.call_task['worker'])
+
+ def test_run_ps(self):
+ executor = self._test_run_task_and_executor(
+ run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))
+ executor.run()
+ self.assertEqual(1, executor.call_task['ps'])
+
+ def test_run_evaluator(self):
+ executor = self._test_run_task_and_executor(
+ run_config=_create_run_config_with_cluster_spec(
+ _TF_CONFIG_FOR_EVALUATOR))
+ executor.run()
+ self.assertEqual(1, executor.call_task['evaluator'])
+
+ def test_run_local(self):
+ executor = self._test_run_task_and_executor(
+ run_config=run_config_lib.RunConfig())
+ executor.run()
+ self.assertEqual(1, executor.call_task['local'])
+
+ def test_invalid_local_task(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ 'local': ['hos1:1'],
+ },
+ 'task': {
+ 'type': 'local', # invalid task type.
+ 'index': 0
+ }
+ }
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = _create_run_config_with_cluster_spec(tf_config)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+ with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):
+ executor.run()
+
+ def test_unsupported_task_due_to_missing_run_task(self):
+ unsupported_task = 'alloc'
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ unsupported_task: ['hos1:1'],
+ },
+ 'task': {
+ 'type': unsupported_task,
+ 'index': 0
+ }
+ }
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = _create_run_config_with_cluster_spec(tf_config)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+ with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
+ executor.run()
+
+ def test_unsupported_task_due_to_not_callable(self):
+ unsupported_task = 'alloc'
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ unsupported_task: ['hos1:1'],
+ },
+ 'task': {
+ 'type': unsupported_task,
+ 'index': 0
+ }
+ }
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = _create_run_config_with_cluster_spec(tf_config)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+ executor.run_alloc = 123 # not callable
+ with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
+ executor.run()
+
+ def test_invalid_task_type(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = test.mock.Mock()
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
+
+ mock_est.config = test.mock.Mock()
+ mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']})
+ mock_est.config.task_type = ''
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec,
+ mock_eval_spec)
+ with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
+ executor.run()
+
+
class TrainAndEvaluateIntegrationTest(test.TestCase):
def setUp(self):