path: root/tensorflow/python/estimator/training.py
author Jianwei Xie <xiejw@google.com> 2017-10-04 10:27:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-04 10:34:50 -0700
commit 6a1b867ff939211673abe6ebe2d3989c74084403 (patch)
tree 626b8af06908fc12522a2625069ec09b9d07f9fb /tensorflow/python/estimator/training.py
parent 7209c1602dc71cb118ab3fa6af282b85b63bd4ad (diff)
Adds the docstring with details for tf.estimator.train_and_evaluate
PiperOrigin-RevId: 171027527
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--  tensorflow/python/estimator/training.py  212
1 file changed, 189 insertions(+), 23 deletions(-)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 604c1a356c..df0b602309 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -112,9 +112,10 @@ def _is_google_env():
class TrainSpec(
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
- """Objects passed to `train_and_evaluate`.
+ """Configuration for the "train" part for the `train_and_evaluate` call.
- `TrainSpec` fully defines the objects to be run by `Estimator.train`.
+ `TrainSpec` determines the input data for the training, as well as the
+ duration. Optional hooks run at various stages of training.
"""
def __new__(cls,
@@ -127,9 +128,10 @@ class TrainSpec(
input_fn: Training input function returning a tuple of:
features - `Tensor` or dictionary of string feature name to `Tensor`.
labels - `Tensor` or dictionary of `Tensor` with labels.
- max_steps: Int. Number of total steps for which to train model. If `None`,
- train forever or train until `input_fn` generates the `OutOfRange` error
- or `StopIteration` exception. See `Estimator.train` for details.
+ max_steps: Int. Positive number of total steps for which to train model.
+ If `None`, train forever. The training `input_fn` is not expected to
+ generate `OutOfRangeError` or `StopIteration` exceptions. See the
+ `train_and_evaluate` stop condition section for details.
hooks: Iterable of `tf.train.SessionRunHook` objects to run
on all workers (including chief) during training.
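
For illustration only, a `TrainSpec` carrying a hook might look like the following minimal sketch; `train_input_fn`, the tensor name `'loss'`, and the step count are assumptions, not part of this API's contract:
```python
# A minimal sketch, assuming `train_input_fn` is defined elsewhere and the
# model graph exposes a tensor named 'loss' (both are illustrative).
logging_hook = tf.train.LoggingTensorHook({'loss': 'loss'}, every_n_iter=100)
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,   # training input pipeline
    max_steps=10000,           # stop after 10,000 global steps
    hooks=[logging_hook])      # runs on all training workers, incl. chief
```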
@@ -137,8 +139,8 @@ class TrainSpec(
A validated `TrainSpec` object.
Raises:
- ValueError: If validation fails.
- TypeError: If any of the arguments is not the expected type.
+ ValueError: If any of the input arguments is invalid.
+ TypeError: If any of the arguments is not of the expected type.
"""
# Validate input_fn.
_validate_input_fn(input_fn)
@@ -163,10 +165,12 @@ class EvalSpec(
'input_fn', 'steps', 'name', 'hooks', 'exporters',
'delay_secs', 'throttle_secs'
])):
- """Objects passed to `train_and_evaluate`.
+ """Configuration for the "eval" part for the `train_and_evaluate` call.
- `EvalSpec` fully defines the objects to be run by `Estimator.evaluate` and
- `Estimator.export_savedmodel`.
+ `EvalSpec` combines details of evaluation of the trained model as well as its
+ export. Evaluation consists of computing metrics to judge the performance of
+ the trained model. Export writes out the trained model onto external
+ storage.
"""
def __new__(cls,
@@ -180,12 +184,12 @@ class EvalSpec(
"""Creates a validated `EvalSpec` instance.
Args:
- input_fn: Training input function returning a tuple of:
+ input_fn: Evaluation input function returning a tuple of:
features - `Tensor` or dictionary of string feature name to `Tensor`.
labels - `Tensor` or dictionary of `Tensor` with labels.
- steps: Int. Number of total steps for which to train model. If `None`,
- train forever or train until `input_fn` generates the `OutOfRange` error
- or `StopIteration` exception. See `Estimator.train` for details.
+ steps: Int. Positive number of steps for which to evaluate model. If
+ `None`, evaluates until `input_fn` raises an end-of-input exception.
+ See `Estimator.evaluate` for details.
name: String. Name of the evaluation if user needs to run multiple
evaluations on different data sets. Metrics for different evaluations
are saved in separate folders, and appear separately in tensorboard.
@@ -196,14 +200,14 @@ class EvalSpec(
delay_secs: Int. Start evaluating after waiting for this many seconds.
throttle_secs: Int. Do not re-evaluate unless the last evaluation was
started at least this many seconds ago. Of course, evaluation does not
- occur if no new checkpoint is available, hence, this is the minimum.
+ occur if no new checkpoints are available, hence, this is the minimum.
Returns:
- A validated `TrainSpec` object.
+ A validated `EvalSpec` object.
Raises:
- ValueError: If validation fails.
- TypeError: If any of the arguments is not the expected type.
+ ValueError: If any of the input arguments is invalid.
+ TypeError: If any of the arguments is not of the expected type.
"""
# Validate input_fn.
_validate_input_fn(input_fn)
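
For illustration only, an `EvalSpec` using the fields above might be constructed as in this minimal sketch (assuming an `eval_input_fn` defined by the user; all values are illustrative):
```python
# A minimal sketch, assuming `eval_input_fn` is defined elsewhere.
eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,  # evaluation input pipeline
    steps=100,               # evaluate 100 batches per evaluation run
    name='validation',       # metrics go to a separate 'validation' folder
    delay_secs=120,          # wait 2 minutes before the first evaluation
    throttle_secs=600)       # at least 10 minutes between evaluations
```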
@@ -243,10 +247,168 @@ class EvalSpec(
throttle_secs=throttle_secs)
-# TODO(xiejw): Write detailed docstring to cover local behavior and distributed
-# behavior. Also write examples for both with TF_CONFIG.
def train_and_evaluate(estimator, train_spec, eval_spec):
- """Train and evaluate the `estimator`."""
+ """Train and evaluate the `estimator`.
+
+ This utility function trains, evaluates, and (optionally) exports the model by
+ using the given `estimator`. All training-related specification is held in
+ `train_spec`, including training `input_fn` and training max steps, etc. All
+ evaluation- and export-related specification is held in `eval_spec`, including
+ evaluation `input_fn`, steps, etc.
+
+ This utility function provides consistent behavior for both local
+ (non-distributed) and distributed configurations. Currently, the only
+ supported distributed training configuration is between-graph replication.
+
+ Overfitting: In order to avoid overfitting, it is recommended to set up the
+ training `input_fn` to shuffle the training data properly. It is also
+ recommended to train the model a little longer, say multiple epochs, before
+ performing evaluation, as the input pipeline starts from scratch for each
+ training run. This is particularly important for local training and evaluation.
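
As a minimal sketch of such a shuffling `input_fn` (the in-memory numpy data, feature name `'x'`, and batch size below are purely illustrative; a real pipeline would typically read from files):
```python
import numpy as np
import tensorflow as tf

# Illustrative in-memory data; a real pipeline would read from files.
features = np.random.rand(1000, 10).astype(np.float32)
labels = np.random.randint(0, 2, size=1000).astype(np.int32)

def train_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(({'x': features}, labels))
  # Shuffle so each pass over the data sees examples in a different order.
  dataset = dataset.shuffle(buffer_size=1000).repeat().batch(128)
  return dataset.make_one_shot_iterator().get_next()
```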
+
+ Stop condition: In order to support both distributed and non-distributed
+ configuration reliably, the only supported stop condition for model
+ training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the
+ model is trained forever. *Use with care* if the model's stop condition is
+ different. For example, assume that the model is expected to be trained with
+ one epoch of training data, and the training `input_fn` is configured to throw
+ `OutOfRangeError` after going through one epoch, which stops
+ `Estimator.train`. For a three-training-worker distributed configuration, each
+ training worker is likely to go through the whole epoch independently. So, the
+ model will be trained with three epochs of training data instead of one epoch.
+
+ Example of local (non-distributed) training:
+ ```python
+ # Set up feature columns.
+ categorical_feature_a = categorical_column_with_hash_bucket(...)
+ categorical_feature_a_emb = embedding_column(
+ categorical_column=categorical_feature_a, ...)
+ ... # other feature columns
+
+ estimator = DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, ...],
+ hidden_units=[1024, 512, 256])
+
+ # Or set up the model directory
+ # estimator = DNNClassifier(
+ # config=tf.estimator.RunConfig(
+ # model_dir='/my_model', save_summary_steps=100),
+ # feature_columns=[categorical_feature_a_emb, ...],
+ # hidden_units=[1024, 512, 256])
+
+ # Input pipeline for train and evaluate.
+ def train_input_fn(): # returns x, y
+   # Please shuffle the data.
+   pass
+ def eval_input_fn(): # returns x, y
+   pass
+
+ train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
+ eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+ ```
+
+ Example of distributed training:
+
+ For distributed training, the code above can be used without change (please
+ do make sure that `RunConfig.model_dir` for all workers is set to the same
+ directory, i.e., a shared file system that all workers can read and write).
+ The only extra work is setting the environment variable `TF_CONFIG`
+ appropriately for each worker.
+
+ Also see: https://www.tensorflow.org/deploy/distributed
+
+ Setting the environment variable depends on the platform. For example, on
+ Linux, it can be done as follows (`$` is the shell prompt):
+ ```
+ $ TF_CONFIG="<replace_with_real_content>" python train_model.py
+ ```
+
+ For the content in `TF_CONFIG`, assume that the training cluster spec looks
+ like:
+ ```
+ cluster = {'chief': ['host0:2222'],
+ 'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
+ 'ps': ['host4:2222', 'host5:2222']}
+ ```
+
+ Example of `TF_CONFIG` for chief training worker (must have one and only one):
+ ```
+ # This should be a JSON string, which is set as an environment variable. Usually
+ # the cluster manager handles that.
+ TF_CONFIG="{
+ 'cluster': {
+ 'chief': ['host0:2222'],
+ 'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
+ 'ps': ['host4:2222', 'host5:2222']
+ },
+ 'task': {'type': 'chief', 'index': 0}
+ }"
+ ```
+ Note that the chief worker also does the model training job, similar to other
+ non-chief training workers (see next paragraph). In addition to the model
+ training, it manages some extra work, e.g., checkpoint saving and restoring,
+ writing summaries, etc.
+
+ Example of `TF_CONFIG` for non-chief training worker (optional, could be
+ multiple):
+ ```
+ # This should be a JSON string, which is set as an environment variable. Usually
+ # the cluster manager handles that.
+ TF_CONFIG="{
+ 'cluster': {
+ 'chief': ['host0:2222'],
+ 'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
+ 'ps': ['host4:2222', 'host5:2222']
+ },
+ 'task': {'type': 'worker', 'index': 0}
+ }"
+ ```
+ where `task.index` should be set to 0, 1, and 2 in this example, respectively,
+ for the non-chief training workers.
+
+ Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):
+ ```
+ # This should be a JSON string, which is set as an environment variable. Usually
+ # the cluster manager handles that.
+ TF_CONFIG="{
+ 'cluster': {
+ 'chief': ['host0:2222'],
+ 'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
+ 'ps': ['host4:2222', 'host5:2222']
+ },
+ 'task': {'type': 'ps', 'index': 0}
+ }"
+ ```
+ where `task.index` should be set to 0 and 1 in this example, respectively,
+ for the parameter servers.
+
+ Example of `TF_CONFIG` for the evaluator task. The evaluator is a special
+ task that is not part of the training cluster. There could be only one; it
+ is used for model evaluation.
+ ```
+ # This should be a JSON string, which is set as an environment variable. Usually
+ # the cluster manager handles that.
+ TF_CONFIG="{
+ 'cluster': {
+ 'chief': ['host0:2222'],
+ 'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
+ 'ps': ['host4:2222', 'host5:2222']
+ },
+ 'task': {'type': 'evaluator', 'index': 0}
+ }"
+ ```
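
For illustration, `TF_CONFIG` can also be assembled programmatically before the `RunConfig` (and hence the `Estimator`) is constructed; a minimal sketch for the chief task of the cluster spec above (usually the cluster manager sets this variable instead):
```python
import json
import os

cluster = {'chief': ['host0:2222'],
           'worker': ['host1:2222', 'host2:2222', 'host3:2222'],
           'ps': ['host4:2222', 'host5:2222']}
# Must be set before constructing RunConfig/Estimator, since TF_CONFIG is
# read when the RunConfig is built.
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster, 'task': {'type': 'chief', 'index': 0}})
```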
+
+ Args:
+ estimator: An `Estimator` instance to train and evaluate.
+ train_spec: A `TrainSpec` instance to specify the training specification.
+ eval_spec: An `EvalSpec` instance to specify the evaluation and export
+ specification.
+
+ Raises:
+ ValueError: If the environment variable `TF_CONFIG` is incorrectly set.
+ """
if not isinstance(estimator, estimator_lib.Estimator):
raise TypeError('`estimator` must have type `tf.estimator.Estimator`, '
@@ -259,7 +421,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
if (not config.cluster_spec and
config.task_type != run_config_lib.TaskType.EVALUATOR):
logging.info('Running training and evaluation locally (non-distributed).')
- return executor.run_local()
+ executor.run_local()
+ return
# Distributed case.
if not config.task_type:
@@ -269,6 +432,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'`estimator.config` must have task_type set. This usually means '
'TF_CONFIG environment is not set correctly.')
+ # TODO(xiejw): error out if evaluator index is more than 0.
+
if config.task_type == 'local':
raise ValueError(
'`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
@@ -284,7 +449,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
raise ValueError(
'Task type {} is not supported. Supported task types are {}'.format(
config.task_type, [x[len('run_'):] for x in available_tasks]))
- return getattr(executor, task_to_run)()
+ getattr(executor, task_to_run)()
+ return
class _StopAtSecsHook(session_run_hook.SessionRunHook):