path: root/tensorflow/examples/learn
author    Mustafa Ispir <ispir@google.com>    2017-06-23 13:18:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-06-23 13:22:27 -0700
commit    160ff06ac0cd0551044c7f650a3bb0d6f3d074f5 (patch)
tree      f0b57a7c1e88e8dd495d77d70a6e856c321f92b9 /tensorflow/examples/learn
parent    e01611369f29eb18565bc77512884b908fde70ff (diff)
Updated wide-n-deep tutorial code to use core version of estimators and feature-columns.
PiperOrigin-RevId: 159984663
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--  tensorflow/examples/learn/wide_n_deep_tutorial.py | 242
1 file changed, 123 insertions(+), 119 deletions(-)
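In short, the change swaps each deprecated tf.contrib symbol for its core
equivalent. A rough summary of the renames as they appear in the diff below
(illustrative, not an exhaustive mapping):

    tf.contrib.layers.sparse_column_with_keys         ->  tf.feature_column.categorical_column_with_vocabulary_list
    tf.contrib.layers.sparse_column_with_hash_bucket  ->  tf.feature_column.categorical_column_with_hash_bucket
    tf.contrib.layers.real_valued_column              ->  tf.feature_column.numeric_column
    tf.contrib.layers.bucketized_column               ->  tf.feature_column.bucketized_column
    tf.contrib.layers.crossed_column                  ->  tf.feature_column.crossed_column
    tf.contrib.layers.embedding_column                ->  tf.feature_column.embedding_column
    tf.contrib.learn.LinearClassifier                 ->  tf.estimator.LinearClassifier
    tf.contrib.learn.DNNClassifier                    ->  tf.estimator.DNNClassifier
    tf.contrib.learn.DNNLinearCombinedClassifier      ->  tf.estimator.DNNLinearCombinedClassifier
    m.fit(input_fn=...)                               ->  m.train(input_fn=...)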
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index a0c6df821a..6a3ae50f0b 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -21,21 +21,89 @@ import argparse
import sys
import tempfile
-from six.moves import urllib
-
import pandas as pd
+from six.moves import urllib
import tensorflow as tf
-COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
- "marital_status", "occupation", "relationship", "race", "gender",
- "capital_gain", "capital_loss", "hours_per_week", "native_country",
- "income_bracket"]
-LABEL_COLUMN = "label"
-CATEGORICAL_COLUMNS = ["workclass", "education", "marital_status", "occupation",
- "relationship", "race", "gender", "native_country"]
-CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
- "hours_per_week"]
+CSV_COLUMNS = [
+ "age", "workclass", "fnlwgt", "education", "education_num",
+ "marital_status", "occupation", "relationship", "race", "gender",
+ "capital_gain", "capital_loss", "hours_per_week", "native_country",
+ "income_bracket"
+]
+
+gender = tf.feature_column.categorical_column_with_vocabulary_list(
+    "gender", ["Female", "Male"])
+education = tf.feature_column.categorical_column_with_vocabulary_list(
+ "education", [
+ "Bachelors", "HS-grad", "11th", "Masters", "9th",
+ "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
+ "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
+ "Preschool", "12th"
+ ])
+marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
+ "marital_status", [
+ "Married-civ-spouse", "Divorced", "Married-spouse-absent",
+ "Never-married", "Separated", "Married-AF-spouse", "Widowed"
+ ])
+relationship = tf.feature_column.categorical_column_with_vocabulary_list(
+ "relationship", [
+ "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
+ "Other-relative"
+ ])
+workclass = tf.feature_column.categorical_column_with_vocabulary_list(
+ "workclass", [
+ "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
+        "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
+ ])
+
+# To show an example of hashing:
+occupation = tf.feature_column.categorical_column_with_hash_bucket(
+ "occupation", hash_bucket_size=1000)
+native_country = tf.feature_column.categorical_column_with_hash_bucket(
+ "native_country", hash_bucket_size=1000)
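+# (A hash bucket column hashes each raw string into one of hash_bucket_size
+# integer ids, so no explicit vocabulary is needed; distinct values can
+# collide into the same bucket.)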
+
+# Continuous base columns.
+age = tf.feature_column.numeric_column("age")
+education_num = tf.feature_column.numeric_column("education_num")
+capital_gain = tf.feature_column.numeric_column("capital_gain")
+capital_loss = tf.feature_column.numeric_column("capital_loss")
+hours_per_week = tf.feature_column.numeric_column("hours_per_week")
+
+# Transformations.
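+# (bucketized_column turns the numeric "age" column into a categorical one
+# with len(boundaries) + 1 buckets, one per interval.)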
+age_buckets = tf.feature_column.bucketized_column(
+ age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
+
+# Wide columns and deep columns.
+base_columns = [
+ gender, native_country, education, occupation, workclass, relationship,
+ age_buckets,
+]
+
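+# (A crossed column captures feature interactions, e.g. education x
+# occupation, by hashing the combined values into hash_bucket_size buckets.)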
+crossed_columns = [
+ tf.feature_column.crossed_column(
+ ["education", "occupation"], hash_bucket_size=1000),
+ tf.feature_column.crossed_column(
+ [age_buckets, "education", "occupation"], hash_bucket_size=1000),
+ tf.feature_column.crossed_column(
+ ["native_country", "occupation"], hash_bucket_size=1000)
+]
+
+deep_columns = [
+ tf.feature_column.indicator_column(workclass),
+ tf.feature_column.indicator_column(education),
+ tf.feature_column.indicator_column(gender),
+ tf.feature_column.indicator_column(relationship),
+ # To show an example of embedding
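+  # (an embedding column maps each category id of the hashed column to a
+  # dense, trainable vector of the given dimension):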
+ tf.feature_column.embedding_column(native_country, dimension=8),
+ tf.feature_column.embedding_column(occupation, dimension=8),
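+  # Numeric columns can be fed to the DNN directly.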
+ age,
+ education_num,
+ capital_gain,
+ capital_loss,
+ hours_per_week,
+]
def maybe_download(train_data, test_data):
@@ -44,7 +112,9 @@ def maybe_download(train_data, test_data):
train_file_name = train_data
else:
train_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long
+ urllib.request.urlretrieve(
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
+ train_file.name) # pylint: disable=line-too-long
train_file_name = train_file.name
train_file.close()
print("Training data is downloaded to %s" % train_file_name)
@@ -53,138 +123,72 @@ def maybe_download(train_data, test_data):
test_file_name = test_data
else:
test_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long
+ urllib.request.urlretrieve(
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
+ test_file.name) # pylint: disable=line-too-long
test_file_name = test_file.name
test_file.close()
print("Test data is downloaded to %s" % test_file_name)
return train_file_name, test_file_name
def build_estimator(model_dir, model_type):
"""Build an estimator."""
- # Sparse base columns.
- gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender",
- keys=["female", "male"])
- education = tf.contrib.layers.sparse_column_with_hash_bucket(
- "education", hash_bucket_size=1000)
- relationship = tf.contrib.layers.sparse_column_with_hash_bucket(
- "relationship", hash_bucket_size=100)
- workclass = tf.contrib.layers.sparse_column_with_hash_bucket(
- "workclass", hash_bucket_size=100)
- occupation = tf.contrib.layers.sparse_column_with_hash_bucket(
- "occupation", hash_bucket_size=1000)
- native_country = tf.contrib.layers.sparse_column_with_hash_bucket(
- "native_country", hash_bucket_size=1000)
-
- # Continuous base columns.
- age = tf.contrib.layers.real_valued_column("age")
- education_num = tf.contrib.layers.real_valued_column("education_num")
- capital_gain = tf.contrib.layers.real_valued_column("capital_gain")
- capital_loss = tf.contrib.layers.real_valued_column("capital_loss")
- hours_per_week = tf.contrib.layers.real_valued_column("hours_per_week")
-
- # Transformations.
- age_buckets = tf.contrib.layers.bucketized_column(age,
- boundaries=[
- 18, 25, 30, 35, 40, 45,
- 50, 55, 60, 65
- ])
-
- # Wide columns and deep columns.
- wide_columns = [gender, native_country, education, occupation, workclass,
- relationship, age_buckets,
- tf.contrib.layers.crossed_column([education, occupation],
- hash_bucket_size=int(1e4)),
- tf.contrib.layers.crossed_column(
- [age_buckets, education, occupation],
- hash_bucket_size=int(1e6)),
- tf.contrib.layers.crossed_column([native_country, occupation],
- hash_bucket_size=int(1e4))]
- deep_columns = [
- tf.contrib.layers.embedding_column(workclass, dimension=8),
- tf.contrib.layers.embedding_column(education, dimension=8),
- tf.contrib.layers.embedding_column(gender, dimension=8),
- tf.contrib.layers.embedding_column(relationship, dimension=8),
- tf.contrib.layers.embedding_column(native_country,
- dimension=8),
- tf.contrib.layers.embedding_column(occupation, dimension=8),
- age,
- education_num,
- capital_gain,
- capital_loss,
- hours_per_week,
- ]
+  # Feature columns are defined at module level above.
if model_type == "wide":
- m = tf.contrib.learn.LinearClassifier(model_dir=model_dir,
- feature_columns=wide_columns)
+ m = tf.estimator.LinearClassifier(
+ model_dir=model_dir, feature_columns=base_columns + crossed_columns)
elif model_type == "deep":
- m = tf.contrib.learn.DNNClassifier(model_dir=model_dir,
- feature_columns=deep_columns,
- hidden_units=[100, 50])
+ m = tf.estimator.DNNClassifier(
+ model_dir=model_dir,
+ feature_columns=deep_columns,
+ hidden_units=[100, 50])
else:
- m = tf.contrib.learn.DNNLinearCombinedClassifier(
+ m = tf.estimator.DNNLinearCombinedClassifier(
model_dir=model_dir,
- linear_feature_columns=wide_columns,
+ linear_feature_columns=crossed_columns,
dnn_feature_columns=deep_columns,
- dnn_hidden_units=[100, 50],
- fix_global_step_increment_bug=True)
+ dnn_hidden_units=[100, 50])
return m
-def input_fn(df):
+def input_fn(data_file, num_epochs, shuffle):
"""Input builder function."""
- # Creates a dictionary mapping from each continuous feature column name (k) to
- # the values of that column stored in a constant Tensor.
- continuous_cols = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS}
- # Creates a dictionary mapping from each categorical feature column name (k)
- # to the values of that column stored in a tf.SparseTensor.
- categorical_cols = {
- k: tf.SparseTensor(
- indices=[[i, 0] for i in range(df[k].size)],
- values=df[k].values,
- dense_shape=[df[k].size, 1])
- for k in CATEGORICAL_COLUMNS}
- # Merges the two dictionaries into one.
- feature_cols = dict(continuous_cols)
- feature_cols.update(categorical_cols)
- # Converts the label column into a constant Tensor.
- label = tf.constant(df[LABEL_COLUMN].values)
- # Returns the feature columns and the label.
- return feature_cols, label
+ df_data = pd.read_csv(
+ tf.gfile.Open(data_file),
+ names=CSV_COLUMNS,
+ skipinitialspace=True,
+ engine="python",
+ skiprows=1)
+  # Remove NaN elements.
+ df_data = df_data.dropna(how="any", axis=0)
+ labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
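+  # pandas_input_fn returns a no-argument function that the Estimator calls
+  # to build the input pipeline from the DataFrame.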
+ return tf.estimator.inputs.pandas_input_fn(
+ x=df_data,
+ y=labels,
+ batch_size=100,
+ num_epochs=num_epochs,
+ shuffle=shuffle,
+ num_threads=5)
def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
"""Train and evaluate the model."""
train_file_name, test_file_name = maybe_download(train_data, test_data)
- df_train = pd.read_csv(
- tf.gfile.Open(train_file_name),
- names=COLUMNS,
- skipinitialspace=True,
- engine="python")
- df_test = pd.read_csv(
- tf.gfile.Open(test_file_name),
- names=COLUMNS,
- skipinitialspace=True,
- skiprows=1,
- engine="python")
-
- # remove NaN elements
- df_train = df_train.dropna(how='any', axis=0)
- df_test = df_test.dropna(how='any', axis=0)
-
- df_train[LABEL_COLUMN] = (
- df_train["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
- df_test[LABEL_COLUMN] = (
- df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
-
model_dir = tempfile.mkdtemp() if not model_dir else model_dir
- print("model directory = %s" % model_dir)
m = build_estimator(model_dir, model_type)
- m.fit(input_fn=lambda: input_fn(df_train), steps=train_steps)
- results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
+  # Set num_epochs to None to get an infinite stream of data.
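+  # Note that input_fn() now returns a function, so it is passed directly
+  # rather than wrapped in a lambda as before.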
+ m.train(
+ input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True),
+ steps=train_steps)
+  # Set steps to None to run evaluation until all data is consumed.
+ results = m.evaluate(
+ input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False),
+ steps=None)
+ print("model directory = %s" % model_dir)
for key in sorted(results):
print("%s: %s" % (key, results[key]))
@@ -215,7 +219,7 @@ if __name__ == "__main__":
parser.add_argument(
"--train_steps",
type=int,
- default=200,
+ default=2000,
help="Number of training steps."
)
parser.add_argument(