diff options
author | 2017-06-30 14:11:23 -0700 | |
---|---|---|
committer | 2017-06-30 14:16:12 -0700 | |
commit | e766804ba7c711ff785b7e14311b56d0d3c9b487 (patch) | |
tree | fe76078158067a6db3fe1a3a2c2e58e9d7c094ea /tensorflow/examples/learn | |
parent | b8237a4583f08f3164b5213f27aaf7c1add0c9a5 (diff) |
Updates wide and wide_n_deep tutorials.
PiperOrigin-RevId: 160686911
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/wide_n_deep_tutorial.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py index 6a3ae50f0b..48c207bed1 100644 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py @@ -34,7 +34,7 @@ CSV_COLUMNS = [ ] gender = tf.feature_column.categorical_column_with_vocabulary_list( - "gender", [" Female", " Male"]) + "gender", ["Female", "Male"]) education = tf.feature_column.categorical_column_with_vocabulary_list( "education", [ "Bachelors", "HS-grad", "11th", "Masters", "9th", @@ -42,7 +42,7 @@ education = tf.feature_column.categorical_column_with_vocabulary_list( "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th", "Preschool", "12th" ]) -tf.feature_column.categorical_column_with_vocabulary_list( +marital_status = tf.feature_column.categorical_column_with_vocabulary_list( "marital_status", [ "Married-civ-spouse", "Divorced", "Married-spouse-absent", "Never-married", "Separated", "Married-AF-spouse", "Widowed" @@ -55,7 +55,7 @@ relationship = tf.feature_column.categorical_column_with_vocabulary_list( workclass = tf.feature_column.categorical_column_with_vocabulary_list( "workclass", [ "Self-emp-not-inc", "Private", "State-gov", "Federal-gov", - "Local-gov", "?", "Self-emp-inc", "Without-pay", " Never-worked" + "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked" ]) # To show an example of hashing: @@ -77,8 +77,8 @@ age_buckets = tf.feature_column.bucketized_column( # Wide columns and deep columns. base_columns = [ - gender, native_country, education, occupation, workclass, relationship, - age_buckets, + gender, education, marital_status, relationship, workclass, occupation, + native_country, age_buckets, ] crossed_columns = [ @@ -135,8 +135,6 @@ def maybe_download(train_data, test_data): def build_estimator(model_dir, model_type): """Build an estimator.""" - # Categorical base columns. - if model_type == "wide": m = tf.estimator.LinearClassifier( model_dir=model_dir, feature_columns=base_columns + crossed_columns) |