diff options
author | 2017-10-04 10:27:49 -0700 | |
---|---|---|
committer | 2017-10-04 10:34:50 -0700 | |
commit | 6a1b867ff939211673abe6ebe2d3989c74084403 (patch) | |
tree | 626b8af06908fc12522a2625069ec09b9d07f9fb /tensorflow/python/estimator/training.py | |
parent | 7209c1602dc71cb118ab3fa6af282b85b63bd4ad (diff) |
Adds the docstring with details for tf.estimator.train_and_evaluate
PiperOrigin-RevId: 171027527
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r-- | tensorflow/python/estimator/training.py | 212 |
1 files changed, 189 insertions, 23 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 604c1a356c..df0b602309 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -112,9 +112,10 @@ def _is_google_env(): class TrainSpec( collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): - """Objects passed to `train_and_evaluate`. + """Configuration for the "train" part for the `train_and_evaluate` call. - `TrainSpec` fully defines the objects to be run by `Estimator.train`. + `TrainSpec` determines the input data for the training, as well as the + duration. Optional hooks run at various stages of training. """ def __new__(cls, @@ -127,9 +128,10 @@ class TrainSpec( input_fn: Training input function returning a tuple of: features - `Tensor` or dictionary of string feature name to `Tensor`. labels - `Tensor` or dictionary of `Tensor` with labels. - max_steps: Int. Number of total steps for which to train model. If `None`, - train forever or train until `input_fn` generates the `OutOfRange` error - or `StopIteration` exception. See `Estimator.train` for details. + max_steps: Int. Positive number of total steps for which to train model. + If `None`, train forever. The training `input_fn` is not expected to + generate `OutOfRangeError` or `StopIteration` exceptions. See the + `train_and_evaluate` stop condition section for details. hooks: Iterable of `tf.train.SessionRunHook` objects to run on all workers (including chief) during training. @@ -137,8 +139,8 @@ class TrainSpec( A validated `TrainSpec` object. Raises: - ValueError: If validation fails. - TypeError: If any of the arguments is not the expected type. + ValueError: If any of the input arguments is invalid. + TypeError: If any of the arguments is not of the expected type. """ # Validate input_fn. 
_validate_input_fn(input_fn) @@ -163,10 +165,12 @@ class EvalSpec( 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'delay_secs', 'throttle_secs' ])): - """Objects passed to `train_and_evaluate`. + """Configuration for the "eval" part for the `train_and_evaluate` call. - `EvalSpec` fully defines the objects to be run by `Estimator.evaluate` and - `Estimator.export_savedmodel`. + `EvalSpec` combines details of evaluation of the trained model as well as its + export. Evaluation consists of computing metrics to judge the performance of + the trained model. Export writes out the trained model on to external + storage. """ def __new__(cls, @@ -180,12 +184,12 @@ class EvalSpec( """Creates a validated `EvalSpec` instance. Args: - input_fn: Training input function returning a tuple of: + input_fn: Evaluation input function returning a tuple of: features - `Tensor` or dictionary of string feature name to `Tensor`. labels - `Tensor` or dictionary of `Tensor` with labels. - steps: Int. Number of total steps for which to train model. If `None`, - train forever or train until `input_fn` generates the `OutOfRange` error - or `StopIteration` exception. See `Estimator.train` for details. + steps: Int. Positive number of steps for which to evaluate model. If + `None`, evaluates until `input_fn` raises an end-of-input exception. + See `Estimator.evaluate` for details. name: String. Name of the evaluation if user needs to run multiple evaluations on different data sets. Metrics for different evaluations are saved in separate folders, and appear separately in tensorboard. @@ -196,14 +200,14 @@ class EvalSpec( delay_secs: Int. Start evaluating after waiting for this many seconds. throttle_secs: Int. Do not re-evaluate unless the last evaluation was started at least this many seconds ago. Of course, evaluation does not - occur if no new checkpoint is available, hence, this is the minimum. + occur if no new checkpoints are available, hence, this is the minimum. 
Returns: - A validated `TrainSpec` object. + A validated `EvalSpec` object. Raises: - ValueError: If validation fails. - TypeError: If any of the arguments is not the expected type. + ValueError: If any of the input arguments is invalid. + TypeError: If any of the arguments is not of the expected type. """ # Validate input_fn. _validate_input_fn(input_fn) @@ -243,10 +247,168 @@ class EvalSpec( throttle_secs=throttle_secs) -# TODO(xiejw): Write detailed docstring to cover local behavior and distributed -# behavior. Also write examples for both with TF_CONFIG. def train_and_evaluate(estimator, train_spec, eval_spec): - """Train and evaluate the `estimator`.""" + """Train and evaluate the `estimator`. + + This utility function trains, evaluates, and (optionally) exports the model by + using the given `estimator`. All training related specification is held in + `train_spec`, including training `input_fn` and training max steps, etc. All + evaluation and export related specification is held in `eval_spec`, including + evaluation `input_fn`, steps, etc. + + This utility function provides consistent behavior for both local + (non-distributed) and distributed configurations. Currently, the only + supported distributed training configuration is between-graph replication. + + Overfitting: In order to avoid overfitting, it is recommended to set up the + training `input_fn` to shuffle the training data properly. It is also + recommended to train the model a little longer, say multiple epochs, before + performing evaluation, as the input pipeline starts from scratch for each + training. It is particularly important for local training and evaluation. + + Stop condition: In order to support both distributed and non-distributed + configuration reliably, the only supported stop condition for model + training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the + model is trained forever. *Use with care* if model stop condition is + different. 
For example, assume that the model is expected to be trained with + one epoch of training data, and the training `input_fn` is configured to throw + `OutOfRangeError` after going through one epoch, which stops the + `Estimator.train`. For a three-training-worker distributed configuration, each + training worker is likely to go through the whole epoch independently. So, the + model will be trained with three epochs of training data instead of one epoch. + + Example of local (non-distributed) training: + ```python + # Set up feature columns. + categorical_feature_a = categorical_column_with_hash_bucket(...) + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + ... # other feature columns + + estimator = DNNClassifier( + feature_columns=[categorical_feature_a_emb, ...], + hidden_units=[1024, 512, 256]) + + # Or set up the model directory + # estimator = DNNClassifier( + # config=tf.estimator.RunConfig( + # model_dir='/my_model', save_summary_steps=100), + # feature_columns=[categorical_feature_a_emb, ...], + # hidden_units=[1024, 512, 256]) + + # Input pipeline for train and evaluate. + def train_input_fn(): # returns x, y + # please shuffle the data. + pass + def eval_input_fn(): # returns x, y + pass + + train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) + eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) + + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + ``` + + Example of distributed training: + + Regarding the example of distributed training, the code above can be used + without a change (Please do make sure that the `RunConfig.model_dir` for all + workers is set to the same directory, i.e., a shared file system all workers + can read and write). The only extra work to do is setting the environment + variable `TF_CONFIG` properly for each worker correspondingly. 
+ + Also see: https://www.tensorflow.org/deploy/distributed + + Setting environment variable depends on the platform. For example, on Linux, + it can be done as follows (`$` is the shell prompt): + ``` + $ TF_CONFIG="<replace_with_real_content>" python train_model.py + ``` + + For the content in `TF_CONFIG`, assume that the training cluster spec looks + like: + ``` + cluster = {'chief': ['host0:2222'], + 'worker': ['host1:2222', 'host2:2222', 'host3:2222'], + 'ps': ['host4:2222', 'host5:2222']} + ``` + + Example of `TF_CONFIG` for chief training worker (must have one and only one): + ``` + # This should be a JSON string, which is set as environment variable. Usually + # the cluster manager handles that. + TF_CONFIG="{ + 'cluster': { + 'chief': ['host0:2222'], + 'worker': ['host1:2222', 'host2:2222', 'host3:2222'], + 'ps': ['host4:2222', 'host5:2222'] + }, + 'task': {'type': 'chief', 'index': 0} + }" + ``` + Note that the chief worker also does the model training job, similar to other + non-chief training workers (see next paragraph). In addition to the model + training, it manages some extra work, e.g., checkpoint saving and restoring, + writing summaries, etc. + + Example of `TF_CONFIG` for non-chief training worker (optional, could be + multiple): + ``` + # This should be a JSON string, which is set as environment variable. Usually + # the cluster manager handles that. + TF_CONFIG="{ + 'cluster': { + 'chief': ['host0:2222'], + 'worker': ['host1:2222', 'host2:2222', 'host3:2222'], + 'ps': ['host4:2222', 'host5:2222'] + }, + 'task': {'type': 'worker', 'index': 0} + }" + ``` + where the `task.index` should be set as 0, 1, 2, in this example, respectively + for non-chief training workers. + + Example of `TF_CONFIG` for parameter server, aka ps (could be multiple): + ``` + # This should be a JSON string, which is set as environment variable. Usually + # the cluster manager handles that. 
+ TF_CONFIG="{ + 'cluster': { + 'chief': ['host0:2222'], + 'worker': ['host1:2222', 'host2:2222', 'host3:2222'], + 'ps': ['host4:2222', 'host5:2222'] + }, + 'task': {'type': 'ps', 'index': 0} + }" + ``` + where the `task.index` should be set as 0 and 1, in this example, respectively + for parameter servers. + + Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is + not part of the training cluster. There could be only one. It is used for + model evaluation. + ``` + # This should be a JSON string, which is set as environment variable. Usually + # the cluster manager handles that. + TF_CONFIG="{ + 'cluster': { + 'chief': ['host0:2222'], + 'worker': ['host1:2222', 'host2:2222', 'host3:2222'], + 'ps': ['host4:2222', 'host5:2222'] + }, + 'task': {'type': 'evaluator', 'index': 0} + }" + ``` + + Args: + estimator: An `Estimator` instance to train and evaluate. + train_spec: A `TrainSpec` instance to specify the training specification. + eval_spec: An `EvalSpec` instance to specify the evaluation and export + specification. + + Raises: + ValueError: if environment variable `TF_CONFIG` is incorrectly set. + """ if not isinstance(estimator, estimator_lib.Estimator): raise TypeError('`estimator` must have type `tf.estimator.Estimator`, ' @@ -259,7 +421,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec): if (not config.cluster_spec and config.task_type != run_config_lib.TaskType.EVALUATOR): logging.info('Running training and evaluation locally (non-distributed).') - return executor.run_local() + executor.run_local() + return # Distributed case. if not config.task_type: @@ -269,6 +432,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec): '`estimator.config` must have task_type set. This usually means ' 'TF_CONFIG environment is not set correctly.') + # TODO(xiejw): error out if evaluator index is more than 0. + if config.task_type == 'local': raise ValueError( '`task.type` in TF_CONFIG cannot be `local`. 
Leaving `cluster` and ' @@ -284,7 +449,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec): raise ValueError( 'Task type {} is not supported. Supported task types are {}'.format( config.task_type, [x[len('run_'):] for x in available_tasks])) - return getattr(executor, task_to_run)() + getattr(executor, task_to_run)() + return class _StopAtSecsHook(session_run_hook.SessionRunHook): |