diff options
author | Jianwei Xie <xiejw@google.com> | 2017-09-27 09:24:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-27 09:28:44 -0700 |
commit | 8b9256106334c2c1a78765992b4f6e94e8074f4d (patch) | |
tree | c589f856c7628a56d602ef48fc94a283c21ec5d4 /tensorflow/python/estimator/training.py | |
parent | 01b75170bbc42358109101c3103454dfd86cf0ee (diff) |
Adds implementation for tf.estimator.train_and_evaluate
PiperOrigin-RevId: 170207452
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r-- | tensorflow/python/estimator/training.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 0dadfc4adf..565ed0b599 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -202,6 +202,50 @@ class EvalSpec( throttle_secs=throttle_secs) +# TODO(xiejw): Write detailed docstring to cover local behavior and distributed +# behavior. Also write examples for both with TF_CONFIG. +def train_and_evaluate(estimator, train_spec, eval_spec): + """Train and evaluate the `estimator`.""" + + if not isinstance(estimator, estimator_lib.Estimator): + raise TypeError('`estimator` must have type `tf.estimator.Estimator`, ' + 'given {}'.format(type(estimator))) + config = estimator.config + + executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec, + eval_spec=eval_spec) + + if (not config.cluster_spec and + config.task_type != run_config_lib.TaskType.EVALUATOR): + logging.info('Running training and evaluation locally (non-distributed).') + return executor.run_local() + + # Distributed case. + if not config.task_type: + # TODO(xiejw): Improve the error message about how to set the TF_CONFIG + # correctly. + raise ValueError( + '`estimator.config` must have task_type set. This usually means ' + 'TF_CONFIG environment is not set correctly.') + + if config.task_type == 'local': + raise ValueError( + '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and ' + '`task` properties in TF_CONFIG absent triggers train and evaluate ' + '`Estimator` locally (non-distributed).') + + # For task type foo, call executor.run_foo. + available_tasks = [x for x in dir(executor) if x.startswith('run_') + and x != 'run_local' + and callable(getattr(executor, x))] + task_to_run = 'run_' + config.task_type + if task_to_run not in available_tasks: + raise ValueError( + 'Task type {} is not supported. Supported task types are {}'.format( + config.task_type, [x[len('run_'):] for x in available_tasks])) + return getattr(executor, task_to_run)() + + class _StopAtSecsHook(session_run_hook.SessionRunHook): """Stops given secs after begin is called.""" |