aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-05-11 13:15:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 13:18:26 -0700
commite1e820efb0e46dafd70d8a776b26962927e64454 (patch)
treeccbb64bbf1e726cea84e04c69dc28f25344b5bfb
parent6479ba550226263e1da8a58ad6c81095693e2751 (diff)
Improve the docstring for learn_runner with new experiment_fn.
PiperOrigin-RevId: 155787311
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_runner.py25
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_runner_test.py3
2 files changed, 22 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py
index 983ac7462a..a3398a87e1 100644
--- a/tensorflow/contrib/learn/python/learn/learn_runner.py
+++ b/tensorflow/contrib/learn/python/learn/learn_runner.py
@@ -68,7 +68,8 @@ def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False):
if not isinstance(run_config, run_config_lib.RunConfig):
raise ValueError('`run_config` must be `RunConfig` instance')
if not run_config.model_dir:
- raise ValueError('Must specify a model directory in `run_config`.')
+ raise ValueError(
+ 'Must specify a model directory `model_dir` in `run_config`.')
if hparams is not None and not isinstance(hparams, hparam_lib.HParams):
raise ValueError('`hparams` must be `HParams` instance')
if require_hparams and hparams is None:
@@ -110,6 +111,10 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None,
Example with `run_config` (Recommended):
```
def _create_my_experiment(run_config, hparams):
+
+ # You can change a subset of the run_config properties as
+ # run_config = run_config.replace(save_checkpoints_steps=500)
+
return tf.contrib.learn.Experiment(
estimator=my_estimator(config=run_config, hparams=hparams),
train_input_fn=my_train_input,
@@ -118,8 +123,17 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None,
learn_runner.run(
experiment_fn=_create_my_experiment,
run_config=run_config_lib.RunConfig(model_dir="some/output/dir"),
- schedule="train",
+ schedule="train_and_evaluate",
hparams=_create_default_hparams())
+ ```
+ or simply as
+ ```
+ learn_runner.run(
+ experiment_fn=_create_my_experiment,
+ run_config=run_config_lib.RunConfig(model_dir="some/output/dir"))
+ ```
+ if `hparams` is not used by the `Estimator`. On a single machine, `schedule`
+ defaults to `train_and_evaluate`.
Example with `output_dir` (deprecated):
```
@@ -147,7 +161,8 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None,
It must return an `Experiment`. For this case, `output_dir` must be None.
output_dir: Base output directory [Deprecated].
schedule: The name of the method in the `Experiment` to run.
- run_config: `RunConfig` instance. If set, `output_dir` must be None.
+ run_config: `RunConfig` instance. The `run_config.model_dir` must be
+ non-empty. If `run_config` is set, `output_dir` must be None.
hparams: `HParams` instance. The default hyper-parameters, which will be
passed to the `experiment_fn` if `run_config` is not None.
@@ -157,8 +172,8 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None,
Raises:
ValueError: If both `output_dir` and `run_config` are empty or set,
`schedule` is None but no task type is set in the built experiment's
- config, the task type has no default, or `schedule` doesn't reference a
- member of `Experiment`.
+ config, the task type has no default, `run_config.model_dir` is empty or
+ `schedule` doesn't reference a member of `Experiment`.
TypeError: `schedule` references non-callable member.
"""
diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_test.py b/tensorflow/contrib/learn/python/learn/learn_runner_test.py
index 77bdcaeb7e..b61a42a1c7 100644
--- a/tensorflow/contrib/learn/python/learn/learn_runner_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_runner_test.py
@@ -36,7 +36,8 @@ patch = test.mock.patch
_MODIR_DIR = "/tmp"
_HPARAMS = hparam_lib.HParams(learning_rate=0.01)
_MUST_SPECIFY_OUTPUT_DIR_MSG = "Must specify an output directory"
-_MISSING_MODEL_DIR_ERR_MSG = "Must specify a model directory in `run_config`."
+_MISSING_MODEL_DIR_ERR_MSG = (
+ "Must specify a model directory `model_dir` in `run_config`.")
_EXP_NOT_CALLABLE_MSG = "Experiment builder .* is not callable"
_INVALID_HPARAMS_ERR_MSG = "`hparams` must be `HParams` instance"
_NOT_EXP_TYPE_MSG = "Experiment builder did not return an Experiment"