diff options
author | Neal Wu <wun@google.com> | 2017-11-14 13:43:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-14 13:47:48 -0800 |
commit | 29d84f18369cfe08beae97cff0aa8bde601b4cfc (patch) | |
tree | 4d00fae36f80b1f145504dcd4df047eb550cd048 /tensorflow/examples/learn | |
parent | 144eaa8e273da43b7ca881d7dcac98b65f698f11 (diff) |
Remove wide_n_deep_tutorial.py in tensorflow/examples/learn in favor of wide_deep.py in the TensorFlow official models
PiperOrigin-RevId: 175728483
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/examples/learn/README.md | 2 | ||||
-rwxr-xr-x | tensorflow/examples/learn/examples_test.sh | 1 | ||||
-rw-r--r-- | tensorflow/examples/learn/wide_n_deep_tutorial.py | 252 |
4 files changed, 1 insertions, 262 deletions
diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD index 23a42a60ba..aba7f600b5 100644 --- a/tensorflow/examples/learn/BUILD +++ b/tensorflow/examples/learn/BUILD @@ -114,13 +114,6 @@ py_binary( ) py_binary( - name = "wide_n_deep_tutorial", - srcs = ["wide_n_deep_tutorial.py"], - srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], -) - -py_binary( name = "mnist", srcs = ["mnist.py"], srcs_version = "PY2AND3", @@ -153,7 +146,6 @@ sh_test( ":text_classification_character_cnn", ":text_classification_character_rnn", ":text_classification_cnn", - ":wide_n_deep_tutorial", ], tags = [ "manual", diff --git a/tensorflow/examples/learn/README.md b/tensorflow/examples/learn/README.md index 70d9db85ee..b74a8f39d9 100644 --- a/tensorflow/examples/learn/README.md +++ b/tensorflow/examples/learn/README.md @@ -23,7 +23,7 @@ processing (`pip install -U pandas`). ## Specialized Models * [Building a Random Forest Model](https://www.tensorflow.org/code/tensorflow/examples/learn/random_forest_mnist.py) -* [Building a Wide & Deep Model](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +* [Building a Wide & Deep Model](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) * [Building a Residual Network Model](https://www.tensorflow.org/code/tensorflow/examples/learn/resnet.py) ## Text classification diff --git a/tensorflow/examples/learn/examples_test.sh b/tensorflow/examples/learn/examples_test.sh index b8763de471..ef5e8a5de2 100755 --- a/tensorflow/examples/learn/examples_test.sh +++ b/tensorflow/examples/learn/examples_test.sh @@ -56,4 +56,3 @@ test text_classification_builtin_rnn_model --test_with_fake_data test text_classification_character_cnn --test_with_fake_data test text_classification_character_rnn --test_with_fake_data test text_classification_cnn --test_with_fake_data -test wide_n_deep_tutorial diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py deleted file mode 100644 index 072353392a..0000000000 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Example code for TensorFlow Wide & Deep Tutorial using TF High Level API. - -This example uses APIs in Tensorflow 1.4 or above. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import shutil -import sys -import tempfile - -import pandas as pd -from six.moves import urllib -import tensorflow as tf - - -CSV_COLUMNS = [ - "age", "workclass", "fnlwgt", "education", "education_num", - "marital_status", "occupation", "relationship", "race", "gender", - "capital_gain", "capital_loss", "hours_per_week", "native_country", - "income_bracket" -] - -gender = tf.feature_column.categorical_column_with_vocabulary_list( - "gender", ["Female", "Male"]) -education = tf.feature_column.categorical_column_with_vocabulary_list( - "education", [ - "Bachelors", "HS-grad", "11th", "Masters", "9th", - "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th", - "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th", - "Preschool", "12th" - ]) -marital_status = tf.feature_column.categorical_column_with_vocabulary_list( - "marital_status", [ - "Married-civ-spouse", "Divorced", "Married-spouse-absent", - "Never-married", "Separated", "Married-AF-spouse", "Widowed" - ]) -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - "relationship", [ - "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried", - "Other-relative" - ]) -workclass = tf.feature_column.categorical_column_with_vocabulary_list( - "workclass", [ - "Self-emp-not-inc", "Private", "State-gov", "Federal-gov", - "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked" - ]) - -# To show an example of hashing: -occupation = tf.feature_column.categorical_column_with_hash_bucket( - "occupation", hash_bucket_size=1000) -native_country = tf.feature_column.categorical_column_with_hash_bucket( - "native_country", hash_bucket_size=1000) - -# Continuous base columns. -age = tf.feature_column.numeric_column("age") -education_num = tf.feature_column.numeric_column("education_num") -capital_gain = tf.feature_column.numeric_column("capital_gain") -capital_loss = tf.feature_column.numeric_column("capital_loss") -hours_per_week = tf.feature_column.numeric_column("hours_per_week") - -# Transformations. -age_buckets = tf.feature_column.bucketized_column( - age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) - -# Wide columns and deep columns. -base_columns = [ - gender, education, marital_status, relationship, workclass, occupation, - native_country, age_buckets, -] - -crossed_columns = [ - tf.feature_column.crossed_column( - ["education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - [age_buckets, "education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - ["native_country", "occupation"], hash_bucket_size=1000) -] - -deep_columns = [ - tf.feature_column.indicator_column(workclass), - tf.feature_column.indicator_column(education), - tf.feature_column.indicator_column(gender), - tf.feature_column.indicator_column(relationship), - # To show an example of embedding - tf.feature_column.embedding_column(native_country, dimension=8), - tf.feature_column.embedding_column(occupation, dimension=8), - age, - education_num, - capital_gain, - capital_loss, - hours_per_week, -] - - -FLAGS = None - - -def maybe_download(train_data, test_data): - """Maybe downloads training data and returns train and test file names.""" - if train_data: - train_file_name = train_data - else: - train_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", - train_file.name) # pylint: disable=line-too-long - train_file_name = train_file.name - train_file.close() - print("Training data is downloaded to %s" % train_file_name) - - if test_data: - test_file_name = test_data - else: - test_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", - test_file.name) # pylint: disable=line-too-long - test_file_name = test_file.name - test_file.close() - print("Test data is downloaded to %s"% test_file_name) - - return train_file_name, test_file_name - - -def build_estimator(model_dir, model_type): - """Build an estimator.""" - if model_type == "wide": - m = tf.estimator.LinearClassifier( - model_dir=model_dir, feature_columns=base_columns + crossed_columns) - elif model_type == "deep": - m = tf.estimator.DNNClassifier( - model_dir=model_dir, - feature_columns=deep_columns, - hidden_units=[100, 50]) - else: - m = tf.estimator.DNNLinearCombinedClassifier( - model_dir=model_dir, - linear_feature_columns=crossed_columns, - dnn_feature_columns=deep_columns, - dnn_hidden_units=[100, 50]) - return m - - -def input_fn(data_file, num_epochs, shuffle): - """Returns an `input_fn` required by Estimator train/evaluate. - - Args: - data_file: The file path to the dataset. - num_epochs: Number of epochs to iterate over data. If `None`, `input_fn` - will generate infinite stream of data. - shuffle: bool, whether to read the data in random order. - """ - df_data = pd.read_csv( - tf.gfile.Open(data_file), - names=CSV_COLUMNS, - skipinitialspace=True, - engine="python", - skiprows=1) - # remove NaN elements - df_data = df_data.dropna(how="any", axis=0) - labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int) - - return tf.estimator.inputs.pandas_input_fn( - x=df_data, - y=labels, - batch_size=100, - num_epochs=num_epochs, - shuffle=shuffle, - num_threads=1) - - -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) - - train_file_name, test_file_name = maybe_download(FLAGS.train_data, - FLAGS.test_data) - - # Specify file path below if want to find the output easily - model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp() - - estimator = build_estimator(model_dir, FLAGS.model_type) - - # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and - # `tf.estimator.train_and_evaluate` API are available in TF 1.4. - train_spec = tf.estimator.TrainSpec( - input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True), - max_steps=FLAGS.train_steps) - - eval_spec = tf.estimator.EvalSpec( - input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False), - # set steps to None to run evaluation until all data consumed. - steps=None) - - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) - - # Manual cleanup - shutil.rmtree(model_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--model_dir", - type=str, - default="", - help="Base directory for output models." - ) - parser.add_argument( - "--model_type", - type=str, - default="wide_n_deep", - help="Valid model types: {'wide', 'deep', 'wide_n_deep'}." - ) - parser.add_argument( - "--train_steps", - type=int, - default=2000, - help="Number of training steps." - ) - parser.add_argument( - "--train_data", - type=str, - default="", - help="Path to the training data." - ) - parser.add_argument( - "--test_data", - type=str, - default="", - help="Path to the test data." - ) - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |