aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/training.py
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-09-27 09:24:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-27 09:28:44 -0700
commit8b9256106334c2c1a78765992b4f6e94e8074f4d (patch)
treec589f856c7628a56d602ef48fc94a283c21ec5d4 /tensorflow/python/estimator/training.py
parent01b75170bbc42358109101c3103454dfd86cf0ee (diff)
Adds implementation for tf.estimator.train_and_evaluate
PiperOrigin-RevId: 170207452
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--tensorflow/python/estimator/training.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 0dadfc4adf..565ed0b599 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -202,6 +202,50 @@ class EvalSpec(
throttle_secs=throttle_secs)
+# TODO(xiejw): Write detailed docstring to cover local behavior and distributed
+# behavior. Also write examples for both with TF_CONFIG.
+def train_and_evaluate(estimator, train_spec, eval_spec):
+ """Train and evaluate the `estimator`."""
+
+ if not isinstance(estimator, estimator_lib.Estimator):
+ raise TypeError('`estimator` must have type `tf.estimator.Estimator`, '
+ 'given {}'.format(type(estimator)))
+ config = estimator.config
+
+ executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
+ eval_spec=eval_spec)
+
+ if (not config.cluster_spec and
+ config.task_type != run_config_lib.TaskType.EVALUATOR):
+ logging.info('Running training and evaluation locally (non-distributed).')
+ return executor.run_local()
+
+ # Distributed case.
+ if not config.task_type:
+ # TODO(xiejw): Improve the error message about how to set the TF_CONFIG
+ # correctly.
+ raise ValueError(
+ '`estimator.config` must have task_type set. This usually means '
+ 'TF_CONFIG environment is not set correctly.')
+
+ if config.task_type == 'local':
+ raise ValueError(
+ '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
+ '`task` properties in TF_CONFIG absent triggers train and evaluate '
+ '`Estimator` locally (non-distributed).')
+
+ # For task type foo, call executor.run_foo.
+ available_tasks = [x for x in dir(executor) if x.startswith('run_')
+ and x != 'run_local'
+ and callable(getattr(executor, x))]
+ task_to_run = 'run_' + config.task_type
+ if task_to_run not in available_tasks:
+ raise ValueError(
+ 'Task type {} is not supported. Supported task types are {}'.format(
+ config.task_type, [x[len('run_'):] for x in available_tasks]))
+ return getattr(executor, task_to_run)()
+
+
class _StopAtSecsHook(session_run_hook.SessionRunHook):
"""Stops given secs after begin is called."""