aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar Neal Wu <wun@google.com>2017-11-14 13:43:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-14 13:47:48 -0800
commit29d84f18369cfe08beae97cff0aa8bde601b4cfc (patch)
tree4d00fae36f80b1f145504dcd4df047eb550cd048 /tensorflow/examples/learn
parent144eaa8e273da43b7ca881d7dcac98b65f698f11 (diff)
Remove wide_n_deep_tutorial.py in tensorflow/examples/learn in favor of wide_deep.py in the TensorFlow official models
PiperOrigin-RevId: 175728483
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/BUILD8
-rw-r--r--tensorflow/examples/learn/README.md2
-rwxr-xr-xtensorflow/examples/learn/examples_test.sh1
-rw-r--r--tensorflow/examples/learn/wide_n_deep_tutorial.py252
4 files changed, 1 insertions, 262 deletions
diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD
index 23a42a60ba..aba7f600b5 100644
--- a/tensorflow/examples/learn/BUILD
+++ b/tensorflow/examples/learn/BUILD
@@ -114,13 +114,6 @@ py_binary(
)
py_binary(
- name = "wide_n_deep_tutorial",
- srcs = ["wide_n_deep_tutorial.py"],
- srcs_version = "PY2AND3",
- deps = ["//tensorflow:tensorflow_py"],
-)
-
-py_binary(
name = "mnist",
srcs = ["mnist.py"],
srcs_version = "PY2AND3",
@@ -153,7 +146,6 @@ sh_test(
":text_classification_character_cnn",
":text_classification_character_rnn",
":text_classification_cnn",
- ":wide_n_deep_tutorial",
],
tags = [
"manual",
diff --git a/tensorflow/examples/learn/README.md b/tensorflow/examples/learn/README.md
index 70d9db85ee..b74a8f39d9 100644
--- a/tensorflow/examples/learn/README.md
+++ b/tensorflow/examples/learn/README.md
@@ -23,7 +23,7 @@ processing (`pip install -U pandas`).
## Specialized Models
* [Building a Random Forest Model](https://www.tensorflow.org/code/tensorflow/examples/learn/random_forest_mnist.py)
-* [Building a Wide & Deep Model](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py)
+* [Building a Wide & Deep Model](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py)
* [Building a Residual Network Model](https://www.tensorflow.org/code/tensorflow/examples/learn/resnet.py)
## Text classification
diff --git a/tensorflow/examples/learn/examples_test.sh b/tensorflow/examples/learn/examples_test.sh
index b8763de471..ef5e8a5de2 100755
--- a/tensorflow/examples/learn/examples_test.sh
+++ b/tensorflow/examples/learn/examples_test.sh
@@ -56,4 +56,3 @@ test text_classification_builtin_rnn_model --test_with_fake_data
test text_classification_character_cnn --test_with_fake_data
test text_classification_character_rnn --test_with_fake_data
test text_classification_cnn --test_with_fake_data
-test wide_n_deep_tutorial
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
deleted file mode 100644
index 072353392a..0000000000
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Example code for TensorFlow Wide & Deep Tutorial using TF High Level API.
-
-This example uses APIs in Tensorflow 1.4 or above.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import shutil
-import sys
-import tempfile
-
-import pandas as pd
-from six.moves import urllib
-import tensorflow as tf
-
-
-CSV_COLUMNS = [
- "age", "workclass", "fnlwgt", "education", "education_num",
- "marital_status", "occupation", "relationship", "race", "gender",
- "capital_gain", "capital_loss", "hours_per_week", "native_country",
- "income_bracket"
-]
-
-gender = tf.feature_column.categorical_column_with_vocabulary_list(
- "gender", ["Female", "Male"])
-education = tf.feature_column.categorical_column_with_vocabulary_list(
- "education", [
- "Bachelors", "HS-grad", "11th", "Masters", "9th",
- "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
- "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
- "Preschool", "12th"
- ])
-marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
- "marital_status", [
- "Married-civ-spouse", "Divorced", "Married-spouse-absent",
- "Never-married", "Separated", "Married-AF-spouse", "Widowed"
- ])
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- "relationship", [
- "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
- "Other-relative"
- ])
-workclass = tf.feature_column.categorical_column_with_vocabulary_list(
- "workclass", [
- "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
- "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
- ])
-
-# To show an example of hashing:
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- "occupation", hash_bucket_size=1000)
-native_country = tf.feature_column.categorical_column_with_hash_bucket(
- "native_country", hash_bucket_size=1000)
-
-# Continuous base columns.
-age = tf.feature_column.numeric_column("age")
-education_num = tf.feature_column.numeric_column("education_num")
-capital_gain = tf.feature_column.numeric_column("capital_gain")
-capital_loss = tf.feature_column.numeric_column("capital_loss")
-hours_per_week = tf.feature_column.numeric_column("hours_per_week")
-
-# Transformations.
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-
-# Wide columns and deep columns.
-base_columns = [
- gender, education, marital_status, relationship, workclass, occupation,
- native_country, age_buckets,
-]
-
-crossed_columns = [
- tf.feature_column.crossed_column(
- ["education", "occupation"], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- [age_buckets, "education", "occupation"], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- ["native_country", "occupation"], hash_bucket_size=1000)
-]
-
-deep_columns = [
- tf.feature_column.indicator_column(workclass),
- tf.feature_column.indicator_column(education),
- tf.feature_column.indicator_column(gender),
- tf.feature_column.indicator_column(relationship),
- # To show an example of embedding
- tf.feature_column.embedding_column(native_country, dimension=8),
- tf.feature_column.embedding_column(occupation, dimension=8),
- age,
- education_num,
- capital_gain,
- capital_loss,
- hours_per_week,
-]
-
-
-FLAGS = None
-
-
-def maybe_download(train_data, test_data):
- """Maybe downloads training data and returns train and test file names."""
- if train_data:
- train_file_name = train_data
- else:
- train_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve(
- "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
- train_file.name) # pylint: disable=line-too-long
- train_file_name = train_file.name
- train_file.close()
- print("Training data is downloaded to %s" % train_file_name)
-
- if test_data:
- test_file_name = test_data
- else:
- test_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve(
- "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
- test_file.name) # pylint: disable=line-too-long
- test_file_name = test_file.name
- test_file.close()
- print("Test data is downloaded to %s"% test_file_name)
-
- return train_file_name, test_file_name
-
-
-def build_estimator(model_dir, model_type):
- """Build an estimator."""
- if model_type == "wide":
- m = tf.estimator.LinearClassifier(
- model_dir=model_dir, feature_columns=base_columns + crossed_columns)
- elif model_type == "deep":
- m = tf.estimator.DNNClassifier(
- model_dir=model_dir,
- feature_columns=deep_columns,
- hidden_units=[100, 50])
- else:
- m = tf.estimator.DNNLinearCombinedClassifier(
- model_dir=model_dir,
- linear_feature_columns=crossed_columns,
- dnn_feature_columns=deep_columns,
- dnn_hidden_units=[100, 50])
- return m
-
-
-def input_fn(data_file, num_epochs, shuffle):
- """Returns an `input_fn` required by Estimator train/evaluate.
-
- Args:
- data_file: The file path to the dataset.
- num_epochs: Number of epochs to iterate over data. If `None`, `input_fn`
- will generate infinite stream of data.
- shuffle: bool, whether to read the data in random order.
- """
- df_data = pd.read_csv(
- tf.gfile.Open(data_file),
- names=CSV_COLUMNS,
- skipinitialspace=True,
- engine="python",
- skiprows=1)
- # remove NaN elements
- df_data = df_data.dropna(how="any", axis=0)
- labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
-
- return tf.estimator.inputs.pandas_input_fn(
- x=df_data,
- y=labels,
- batch_size=100,
- num_epochs=num_epochs,
- shuffle=shuffle,
- num_threads=1)
-
-
-def main(_):
- tf.logging.set_verbosity(tf.logging.INFO)
-
- train_file_name, test_file_name = maybe_download(FLAGS.train_data,
- FLAGS.test_data)
-
- # Specify file path below if want to find the output easily
- model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp()
-
- estimator = build_estimator(model_dir, FLAGS.model_type)
-
- # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and
- # `tf.estimator.train_and_evaluate` API are available in TF 1.4.
- train_spec = tf.estimator.TrainSpec(
- input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True),
- max_steps=FLAGS.train_steps)
-
- eval_spec = tf.estimator.EvalSpec(
- input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False),
- # set steps to None to run evaluation until all data consumed.
- steps=None)
-
- tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
-
- # Manual cleanup
- shutil.rmtree(model_dir)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.register("type", "bool", lambda v: v.lower() == "true")
- parser.add_argument(
- "--model_dir",
- type=str,
- default="",
- help="Base directory for output models."
- )
- parser.add_argument(
- "--model_type",
- type=str,
- default="wide_n_deep",
- help="Valid model types: {'wide', 'deep', 'wide_n_deep'}."
- )
- parser.add_argument(
- "--train_steps",
- type=int,
- default=2000,
- help="Number of training steps."
- )
- parser.add_argument(
- "--train_data",
- type=str,
- default="",
- help="Path to the training data."
- )
- parser.add_argument(
- "--test_data",
- type=str,
- default="",
- help="Path to the test data."
- )
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)