aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-30 14:11:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-30 14:16:12 -0700
commite766804ba7c711ff785b7e14311b56d0d3c9b487 (patch)
treefe76078158067a6db3fe1a3a2c2e58e9d7c094ea /tensorflow/examples/learn
parentb8237a4583f08f3164b5213f27aaf7c1add0c9a5 (diff)
Updates wide and wide_n_deep tutorials.
PiperOrigin-RevId: 160686911
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/wide_n_deep_tutorial.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index 6a3ae50f0b..48c207bed1 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -34,7 +34,7 @@ CSV_COLUMNS = [
]
gender = tf.feature_column.categorical_column_with_vocabulary_list(
- "gender", [" Female", " Male"])
+ "gender", ["Female", "Male"])
education = tf.feature_column.categorical_column_with_vocabulary_list(
"education", [
"Bachelors", "HS-grad", "11th", "Masters", "9th",
@@ -42,7 +42,7 @@ education = tf.feature_column.categorical_column_with_vocabulary_list(
"Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
"Preschool", "12th"
])
-tf.feature_column.categorical_column_with_vocabulary_list(
+marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
"marital_status", [
"Married-civ-spouse", "Divorced", "Married-spouse-absent",
"Never-married", "Separated", "Married-AF-spouse", "Widowed"
@@ -55,7 +55,7 @@ relationship = tf.feature_column.categorical_column_with_vocabulary_list(
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
"workclass", [
"Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
- "Local-gov", "?", "Self-emp-inc", "Without-pay", " Never-worked"
+ "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
])
# To show an example of hashing:
@@ -77,8 +77,8 @@ age_buckets = tf.feature_column.bucketized_column(
# Wide columns and deep columns.
base_columns = [
- gender, native_country, education, occupation, workclass, relationship,
- age_buckets,
+ gender, education, marital_status, relationship, workclass, occupation,
+ native_country, age_buckets,
]
crossed_columns = [
@@ -135,8 +135,6 @@ def maybe_download(train_data, test_data):
def build_estimator(model_dir, model_type):
"""Build an estimator."""
- # Categorical base columns.
-
if model_type == "wide":
m = tf.estimator.LinearClassifier(
model_dir=model_dir, feature_columns=base_columns + crossed_columns)