diff options
author | Jianwei Xie <xiejw@google.com> | 2017-10-04 12:41:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-04 12:46:07 -0700 |
commit | 943c6d7af7a8ccd4f824a2c0f90b251587c63fea (patch) | |
tree | 5a3ad83df0155e06708a7e141823423f218e0206 /tensorflow/python/estimator | |
parent | 8c9ef44668c767dd30de14f49fb96be6e2648243 (diff) |
errors out if the evaluator has task id > 0.
PiperOrigin-RevId: 171047652
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/training.py | 8 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 18 |
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 166b7b20ed..953e970eea 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -438,14 +438,18 @@ def train_and_evaluate(estimator, train_spec, eval_spec): '`estimator.config` must have task_type set. This usually means ' 'TF_CONFIG environment is not set correctly.') - # TODO(xiejw): error out if evaluator index is more than 0. - if config.task_type == 'local': raise ValueError( '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and ' '`task` properties in TF_CONFIG absent triggers train and evaluate ' '`Estimator` locally (non-distributed).') + if (config.task_type == run_config_lib.TaskType.EVALUATOR and + config.task_id > 0): + raise ValueError( + 'For distributed training, there can only be one `evaluator` task ' + '(with task id 0). Given task id {}'.format(config.task_id)) + # For task type foo, call executor.run_foo. available_tasks = [x for x in dir(executor) if x.startswith('run_') and x != 'run_local' diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index c474004dab..e4c400ca7f 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -71,6 +71,8 @@ _INVALID_EMPTY_EVAL_RESULT_ERR = ( _INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.' _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = ( 'Internal error: `Estimator.evaluate` result should have `global_step`') +_INVALID_EVAL_TASK_ID_ERR = ( + 'there can only be one `evaluator` task .*with task id 0') _TF_CONFIG_FOR_CHIEF = { 'cluster': { @@ -128,7 +130,7 @@ _TF_CONFIG_FOR_EVALUATOR = { }, 'task': { 'type': run_config_lib.TaskType.EVALUATOR, - 'index': 1 + 'index': 0 } } @@ -351,6 +353,20 @@ class TrainAndEvaluteTest(test.TestCase): _TF_CONFIG_FOR_EVALUATOR)) self.assertEqual(1, mock_executor.call_task['evaluator']) + def test_error_out_if_evaluator_task_id_is_non_zero(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + }, + 'task': { + 'type': run_config_lib.TaskType.EVALUATOR, + 'index': 1 + } + } + with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR): + self._test_run_task_in_distributed_training( + run_config=_create_run_config_with_cluster_spec(tf_config)) + def test_run_local(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator) mock_est.config = run_config_lib.RunConfig() |