aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-07-27 13:07:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-27 13:11:10 -0700
commit6a19ae36acb6ac60f46b046efc3cc0672a7dca42 (patch)
treebaf4bbba7984de8faa7d0ab27766144be2b2cf89 /tensorflow/contrib/estimator
parentab9f0a628f61fcb19b6b09cb51bf05ff8c702a80 (diff)
Fix SavedModelEstimator docstring formatting.
PiperOrigin-RevId: 206361654
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
index f3d0f6b047..b0082f7e55 100644
--- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
+++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
@@ -46,6 +46,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
Example with `tf.estimator.DNNClassifier`:
**Step 1: Create and train DNNClassifier.**
+
```python
feature1 = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
@@ -66,13 +67,14 @@ class SavedModelEstimator(estimator_lib.Estimator):
**Step 2: Export classifier.**
First, build functions that specify the expected inputs.
+
```python
# During train and evaluation, both the features and labels should be defined.
supervised_input_receiver_fn = (
tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
- {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
- 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
- tf.placeholder(dtype=tf.float32, shape=[None])))
+ {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
+ 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
+ tf.placeholder(dtype=tf.float32, shape=[None])))
# During predict mode, expect to receive a `tf.Example` proto, so a parsing
# function is used.
@@ -83,6 +85,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
Next, export the model as a SavedModel. A timestamped directory will be
created (for example `/tmp/export_all/1234567890`).
+
```python
# Option 1: Save all modes (train, eval, predict)
export_dir = tf.contrib.estimator.export_all_saved_models(
@@ -93,10 +96,11 @@ class SavedModelEstimator(estimator_lib.Estimator):
# Option 2: Only export predict mode
export_dir = classifier.export_savedmodel(
- '/tmp/export_predict', serving_input_receiver_fn)
+ '/tmp/export_predict', serving_input_receiver_fn)
```
**Step 3: Create a SavedModelEstimator from the exported SavedModel.**
+
```python
est = tf.contrib.estimator.SavedModelEstimator(export_dir)
@@ -108,7 +112,7 @@ class SavedModelEstimator(estimator_lib.Estimator):
est.train(input_fn=input_fn, steps=20)
def predict_input_fn():
- example = example_pb2.Example()
+ example = tf.train.Example()
example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
example.features.feature['feature2'].float_list.value.extend([1.])
return {'inputs':tf.constant([example.SerializeToString()])}