aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-10-04 12:41:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 12:46:07 -0700
commit943c6d7af7a8ccd4f824a2c0f90b251587c63fea (patch)
tree5a3ad83df0155e06708a7e141823423f218e0206 /tensorflow/python/estimator
parent8c9ef44668c767dd30de14f49fb96be6e2648243 (diff)
errors out if the evaluator has task id > 0.
PiperOrigin-RevId: 171047652
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/training.py8
-rw-r--r--tensorflow/python/estimator/training_test.py18
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 166b7b20ed..953e970eea 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -438,14 +438,18 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'`estimator.config` must have task_type set. This usually means '
'TF_CONFIG environment is not set correctly.')
- # TODO(xiejw): error out if evaluator index is more than 0.
-
if config.task_type == 'local':
raise ValueError(
'`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
'`task` properties in TF_CONFIG absent triggers train and evaluate '
'`Estimator` locally (non-distributed).')
+ if (config.task_type == run_config_lib.TaskType.EVALUATOR and
+ config.task_id > 0):
+ raise ValueError(
+ 'For distributed training, there can only be one `evaluator` task '
+ '(with task id 0). Given task id {}'.format(config.task_id))
+
# For task type foo, call executor.run_foo.
available_tasks = [x for x in dir(executor) if x.startswith('run_')
and x != 'run_local'
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index c474004dab..e4c400ca7f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -71,6 +71,8 @@ _INVALID_EMPTY_EVAL_RESULT_ERR = (
_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
'Internal error: `Estimator.evaluate` result should have `global_step`')
+_INVALID_EVAL_TASK_ID_ERR = (
+ 'there can only be one `evaluator` task .*with task id 0')
_TF_CONFIG_FOR_CHIEF = {
'cluster': {
@@ -128,7 +130,7 @@ _TF_CONFIG_FOR_EVALUATOR = {
},
'task': {
'type': run_config_lib.TaskType.EVALUATOR,
- 'index': 1
+ 'index': 0
}
}
@@ -351,6 +353,20 @@ class TrainAndEvaluteTest(test.TestCase):
_TF_CONFIG_FOR_EVALUATOR))
self.assertEqual(1, mock_executor.call_task['evaluator'])
+ def test_error_out_if_evaluator_task_id_is_non_zero(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.EVALUATOR,
+ 'index': 1
+ }
+ }
+ with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
+ self._test_run_task_in_distributed_training(
+ run_config=_create_run_config_with_cluster_spec(tf_config))
+
def test_run_local(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = run_config_lib.RunConfig()