aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanders Kleinfeld <skleinfeld@google.com>2017-05-11 13:52:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 13:56:45 -0700
commit20f9b5da8d7477c6cc86ee24fd1c78eb41f789f0 (patch)
tree4f4129b22bec6e441aec740d7bcc704e7762a0ba
parent20be1353372dbb6f2db4b152a37453a2d4e0af14 (diff)
Update to Custom Estimators docs to use input_fn for both .fit() and
.evaluate(). Follow-up to PR #9617 PiperOrigin-RevId: 155792276
-rw-r--r--tensorflow/docs_src/extend/estimators.md32
1 files changed, 25 insertions, 7 deletions
diff --git a/tensorflow/docs_src/extend/estimators.md b/tensorflow/docs_src/extend/estimators.md
index c5444c59ca..f972ee5f50 100644
--- a/tensorflow/docs_src/extend/estimators.md
+++ b/tensorflow/docs_src/extend/estimators.md
@@ -37,14 +37,17 @@ measurements. You'll learn how to do the following:
## Prerequisites
This tutorial assumes you already know tf.contrib.learn API basics, such as
-feature columns and `fit()` operations. If you've never used tf.contrib.learn
-before, or need a refresher, you should first review the following tutorials:
+feature columns, input functions, and `fit()`/`evaluate()`/`predict()`
+operations. If you've never used tf.contrib.learn before, or need a refresher,
+you should first review the following tutorials:
* @{$tflearn$tf.contrib.learn Quickstart}: Quick introduction to
training a neural network using tf.contrib.learn.
* @{$wide$TensorFlow Linear Model Tutorial}: Introduction to
feature columns, and an overview on building a linear classifier in
tf.contrib.learn.
+* @{$input_fn$Building Input Functions with tf.contrib.learn}: Overview of how
+ to construct an input_fn to preprocess and feed data into your models.
## An Abalone Age Predictor {#abalone-predictor}
@@ -239,7 +242,7 @@ nn = tf.contrib.learn.Estimator(
* `params`: An optional dict of hyperparameters (e.g., learning rate, dropout)
that will be passed into the `model_fn`.
-NOTE: Just like `tf.contrib.learn`'s predefined regressors and classifiers, the
+Note: Just like `tf.contrib.learn`'s predefined regressors and classifiers, the
`Estimator` initializer also accepts the general configuration arguments
`model_dir` and `config`.
@@ -252,7 +255,7 @@ code (highlighted in bold below), right after the logging configuration:
<strong># Learning rate for the model
LEARNING_RATE = 0.001</strong></code></pre>
-NOTE: Here, `LEARNING_RATE` is set to `0.001`, but you can tune this value as
+Note: Here, `LEARNING_RATE` is set to `0.001`, but you can tune this value as
needed to achieve the best results during model training.
Then, add the following code to `main()`, which creates the dict `model_params`
@@ -576,7 +579,7 @@ required arguments:
algorithm
(@{tf.train.RMSPropOptimizer})
-NOTE: The `optimize_loss` function supports additional optional arguments to
+Note: The `optimize_loss` function supports additional optional arguments to
further configure the optimizer, such as for implementing decay. See the
@{tf.contrib.layers.optimize_loss$API docs} for more info.
@@ -654,15 +657,30 @@ Add the following code to the end of `main()` to fit the neural network to the
training data and evaluate accuracy:
```python
+def get_train_inputs():
+ x = tf.constant(training_set.data)
+ y = tf.constant(training_set.target)
+ return x, y
+
# Fit
-nn.fit(x=training_set.data, y=training_set.target, steps=5000)
+nn.fit(input_fn=get_train_inputs, steps=5000)
+
+def get_test_inputs():
+ x = tf.constant(test_set.data)
+ y = tf.constant(test_set.target)
+ return x, y
# Score accuracy
-ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1)
+ev = nn.evaluate(input_fn=get_test_inputs, steps=1)
print("Loss: %s" % ev["loss"])
print("Root Mean Squared Error: %s" % ev["rmse"])
```
+Note: The above code uses input functions to feed feature (`x`) and label (`y`)
+`Tensor`s into the model for both training (`get_train_inputs()`) and evaluation
+(`get_test_inputs()`). To learn more about input functions, see the tutorial
+@{$input_fn$Building Input Functions with tf.contrib.learn}.
+
Then run the code. You should see output like the following:
```none