aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/get_started
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2017-08-02 09:39:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-02 09:43:57 -0700
commitf29cfd6363019da5960848ebdaab0576ea52dbb0 (patch)
treebac0abb9e3f67112e0ab4cf2089a837a4428f635 /tensorflow/examples/get_started
parentdfd0df11488f27a89ced83815b69d86ea4bdd44a (diff)
Add examples supporting docs in get_started.
PiperOrigin-RevId: 163994110
Diffstat (limited to 'tensorflow/examples/get_started')
-rw-r--r--tensorflow/examples/get_started/__init__.py19
-rw-r--r--tensorflow/examples/get_started/regression/BUILD37
-rw-r--r--tensorflow/examples/get_started/regression/__init__.py20
-rw-r--r--tensorflow/examples/get_started/regression/dnn_regression.py91
-rw-r--r--tensorflow/examples/get_started/regression/imports85.py107
-rw-r--r--tensorflow/examples/get_started/regression/linear_regression.py96
-rw-r--r--tensorflow/examples/get_started/regression/linear_regression_categorical.py101
-rw-r--r--tensorflow/examples/get_started/regression/test.py73
8 files changed, 544 insertions, 0 deletions
diff --git a/tensorflow/examples/get_started/__init__.py b/tensorflow/examples/get_started/__init__.py
new file mode 100644
index 0000000000..c12e1da97c
--- /dev/null
+++ b/tensorflow/examples/get_started/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A collection of "getting started" examples."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/examples/get_started/regression/BUILD b/tensorflow/examples/get_started/regression/BUILD
new file mode 100644
index 0000000000..779fa1e554
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/BUILD
@@ -0,0 +1,37 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_test(
+ name = "test",
+ size = "medium",
+ srcs = [
+ "dnn_regression.py",
+ "imports85.py",
+ "linear_regression.py",
+ "linear_regression_categorical.py",
+ "test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ "//third_party/py/pandas",
+ ],
+)
diff --git a/tensorflow/examples/get_started/regression/__init__.py b/tensorflow/examples/get_started/regression/__init__.py
new file mode 100644
index 0000000000..b81f4789f5
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A collection of regression examples using `Estimators`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/examples/get_started/regression/dnn_regression.py b/tensorflow/examples/get_started/regression/dnn_regression.py
new file mode 100644
index 0000000000..06f0665e56
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/dnn_regression.py
@@ -0,0 +1,91 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Regression using the DNNRegressor Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import imports85 # pylint: disable=g-bad-import-order
+
+STEPS = 5000
+
+
+def main(argv):
+ """Builds, trains, and evaluates the model."""
+ assert len(argv) == 1
+ (x_train, y_train), (x_test, y_test) = imports85.load_data()
+
+ # Build the training input_fn.
+ input_train = tf.estimator.inputs.pandas_input_fn(
+ x=x_train, y=y_train, num_epochs=None, shuffle=True)
+
+ # Build the validation input_fn.
+ input_test = tf.estimator.inputs.pandas_input_fn(
+ x=x_test, y=y_test, shuffle=True)
+
+ # "body-style" is a categorical feature whose vocabulary is specified here as
+ # a list of options (values outside this specification will receive a weight
+ # of zero). The vocabulary can also be specified with a vocabulary file (using
+ # `categorical_column_with_vocabulary_file`). For features covering a
+ # range of positive integers use `categorical_column_with_identity`. "make"
+ # has no vocabulary specified, so it is hashed into 50 buckets instead.
+ body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
+ body_style = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="body-style", vocabulary_list=body_style_vocab)
+ make = tf.feature_column.categorical_column_with_hash_bucket(
+ key="make", hash_bucket_size=50)
+
+ feature_columns = [
+ tf.feature_column.numeric_column(key="curb-weight"),
+ tf.feature_column.numeric_column(key="highway-mpg"),
+ # Since this is a DNN model, convert categorical columns from sparse
+ # to dense.
+ # Wrap them in an `indicator_column` to create a
+ # one-hot vector from the input.
+ tf.feature_column.indicator_column(body_style),
+ # Or use an `embedding_column` to create a trainable vector for each
+ # index.
+ tf.feature_column.embedding_column(make, dimension=3),
+ ]
+
+ # Build a DNNRegressor, with 2x20-unit hidden layers, with the feature columns
+ # defined above as input.
+ model = tf.estimator.DNNRegressor(
+ hidden_units=[20, 20], feature_columns=feature_columns)
+
+ # Train the model.
+ model.train(input_fn=input_train, steps=STEPS)
+
+ # Evaluate how the model performs on data it has not yet seen.
+ eval_result = model.evaluate(input_fn=input_test)
+
+ # The evaluation returns a Python dictionary. The "average_loss" key holds the
+ # Mean Squared Error (MSE).
+ average_loss = eval_result["average_loss"]
+
+ # Convert MSE to Root Mean Square Error (RMSE).
+ print("\n" + 80 * "*")
+ print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
+
+ print()
+
+
+if __name__ == "__main__":
+ # The Estimator periodically generates "INFO" logs; make these logs visible.
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main=main)
diff --git a/tensorflow/examples/get_started/regression/imports85.py b/tensorflow/examples/get_started/regression/imports85.py
new file mode 100644
index 0000000000..4532064622
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/imports85.py
@@ -0,0 +1,107 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A dataset loader for imports85.data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+header = collections.OrderedDict([
+ ("symboling", np.int32),
+ ("normalized-losses", np.float32),
+ ("make", str),
+ ("fuel-type", str),
+ ("aspiration", str),
+ ("num-of-doors", str),
+ ("body-style", str),
+ ("drive-wheels", str),
+ ("engine-location", str),
+ ("wheel-base", np.float32),
+ ("length", np.float32),
+ ("width", np.float32),
+ ("height", np.float32),
+ ("curb-weight", np.float32),
+ ("engine-type", str),
+ ("num-of-cylinders", str),
+ ("engine-size", np.float32),
+ ("fuel-system", str),
+ ("bore", np.float32),
+ ("stroke", np.float32),
+ ("compression-ratio", np.float32),
+ ("horsepower", np.float32),
+ ("peak-rpm", np.float32),
+ ("city-mpg", np.float32),
+ ("highway-mpg", np.float32),
+ ("price", np.float32)
+]) # pyformat: disable
+
+
+def raw():
+ """Downloads (with caching) imports-85.data and loads it as a pd.DataFrame."""
+ url = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data" # pylint: disable=line-too-long
+ # Download and cache the data.
+ path = tf.contrib.keras.utils.get_file(url.split("/")[-1], url)
+
+ # Load the CSV into a pandas DataFrame; "?" entries are read as missing (NaN).
+ df = pd.read_csv(path, names=header.keys(), dtype=header, na_values="?")
+
+ return df
+
+
+def load_data(y_name="price", train_fraction=0.7, seed=None):
+ """Returns the imports85 data shuffled and split into train and test subsets.
+
+ A description of the data is available at:
+ https://archive.ics.uci.edu/ml/datasets/automobile
+
+ The data itself can be found at:
+ https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data
+
+ Args:
+ y_name: the column to return as the label.
+ train_fraction: the fraction of the dataset to use for training.
+ seed: The random seed to use when shuffling the data. `None` generates a
+ unique shuffle every run.
+ Returns:
+ a pair of pairs where the first pair is the training data, and the second
+ is the test data:
+ `(x_train, y_train), (x_test, y_test) = load_data(...)`
+ `x` contains a pandas DataFrame of features, while `y` contains the label
+ array.
+ """
+ # Load the raw data columns.
+ data = raw()
+
+ # Delete rows with unknowns
+ data = data.dropna()
+
+ # Seed NumPy's global RNG; the rows are randomly sampled below.
+ np.random.seed(seed)
+
+ # Split the data into train/test subsets.
+ x_train = data.sample(frac=train_fraction, random_state=seed)
+ x_test = data.drop(x_train.index)
+
+ # Extract the label from the features dataframe.
+ y_train = x_train.pop(y_name)
+ y_test = x_test.pop(y_name)
+
+ return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/examples/get_started/regression/linear_regression.py b/tensorflow/examples/get_started/regression/linear_regression.py
new file mode 100644
index 0000000000..9793163323
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/linear_regression.py
@@ -0,0 +1,96 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear regression using the LinearRegressor Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+import imports85 # pylint: disable=g-bad-import-order
+
+STEPS = 1000
+
+
+def main(argv):
+ """Builds, trains, and evaluates the model."""
+ assert len(argv) == 1
+ (x_train, y_train), (x_test, y_test) = imports85.load_data()
+
+ # Build the training input_fn.
+ input_train = tf.estimator.inputs.pandas_input_fn(
+ x=x_train,
+ y=y_train,
+ # Setting `num_epochs` to `None` lets the `input_fn` generate data
+ # indefinitely, leaving the call to `Estimator.train` in control.
+ num_epochs=None,
+ shuffle=True)
+
+ # Build the validation input_fn.
+ input_test = tf.estimator.inputs.pandas_input_fn(
+ x=x_test, y=y_test, shuffle=True)
+
+ feature_columns = [
+ # "curb-weight" and "highway-mpg" are numeric columns.
+ tf.feature_column.numeric_column(key="curb-weight"),
+ tf.feature_column.numeric_column(key="highway-mpg"),
+ ]
+
+ # Build the Estimator.
+ model = tf.estimator.LinearRegressor(feature_columns=feature_columns)
+
+ # Train the model.
+ # By default, the Estimators log output every 100 steps.
+ model.train(input_fn=input_train, steps=STEPS)
+
+ # Evaluate how the model performs on data it has not yet seen.
+ eval_result = model.evaluate(input_fn=input_test)
+
+ # The evaluation returns a Python dictionary. The "average_loss" key holds the
+ # Mean Squared Error (MSE).
+ average_loss = eval_result["average_loss"]
+
+ # Convert MSE to Root Mean Square Error (RMSE).
+ print("\n" + 80 * "*")
+ print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
+
+ # Run the model in prediction mode.
+ input_dict = {
+ "curb-weight": np.array([2000, 3000]),
+ "highway-mpg": np.array([30, 40])
+ }
+ predict_input_fn = tf.estimator.inputs.numpy_input_fn(
+ input_dict, shuffle=False)
+ predict_results = model.predict(input_fn=predict_input_fn)
+
+ # Print the prediction results.
+ print("\nPrediction results:")
+ for i, prediction in enumerate(predict_results):
+ msg = ("Curb weight: {: 4d}lbs, "
+ "Highway: {: 0d}mpg, "
+ "Prediction: ${: 9.2f}")
+ msg = msg.format(input_dict["curb-weight"][i], input_dict["highway-mpg"][i],
+ prediction["predictions"][0])
+
+ print(" " + msg)
+ print()
+
+
+if __name__ == "__main__":
+ # The Estimator periodically generates "INFO" logs; make these logs visible.
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main=main)
diff --git a/tensorflow/examples/get_started/regression/linear_regression_categorical.py b/tensorflow/examples/get_started/regression/linear_regression_categorical.py
new file mode 100644
index 0000000000..0a416595e6
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/linear_regression_categorical.py
@@ -0,0 +1,101 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear regression with categorical features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+import imports85 # pylint: disable=g-bad-import-order
+
+STEPS = 1000
+
+
+def main(argv):
+ """Builds, trains, and evaluates the model."""
+ assert len(argv) == 1
+ (x_train, y_train), (x_test, y_test) = imports85.load_data()
+
+ # Build the training input_fn.
+ input_train = tf.estimator.inputs.pandas_input_fn(
+ x=x_train,
+ y=y_train,
+ # Setting `num_epochs` to `None` lets the `input_fn` generate data
+ # indefinitely, leaving the call to `Estimator.train` in control.
+ num_epochs=None,
+ shuffle=True)
+
+ # Build the validation input_fn.
+ input_test = tf.estimator.inputs.pandas_input_fn(
+ x=x_test, y=y_test, shuffle=True)
+
+ # The following code demonstrates two of the ways that `feature_columns` can
+ # be used to build a model with categorical inputs.
+
+ # The first way assigns a unique weight to each category. To do this, you must
+ # specify the category's vocabulary (values outside this specification will
+ # receive a weight of zero).
+ # Alternatively, you can define the vocabulary in a file (by calling
+ # `categorical_column_with_vocabulary_file`) or as a range of positive
+ # integers (by calling `categorical_column_with_identity`).
+ body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
+ body_style_column = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="body-style", vocabulary_list=body_style_vocab)
+
+ # The second way, appropriate for an unspecified vocabulary, is to create a
+ # hashed column. It will create a fixed length list of weights, and
+ # automatically assign each input category to a weight. Due to the
+ # pseudo-randomness of the process, some weights may be shared between
+ # categories, while others will remain unused.
+ make_column = tf.feature_column.categorical_column_with_hash_bucket(
+ key="make", hash_bucket_size=50)
+
+ feature_columns = [
+ # This model uses the same two numeric features as `linear_regression.py`.
+ tf.feature_column.numeric_column(key="curb-weight"),
+ tf.feature_column.numeric_column(key="highway-mpg"),
+ # This model adds two categorical columns that will adjust the price based
+ # on "make" and "body-style".
+ body_style_column,
+ make_column,
+ ]
+
+ # Build the Estimator.
+ model = tf.estimator.LinearRegressor(feature_columns=feature_columns)
+
+ # Train the model.
+ # By default, the Estimators log output every 100 steps.
+ model.train(input_fn=input_train, steps=STEPS)
+
+ # Evaluate how the model performs on data it has not yet seen.
+ eval_result = model.evaluate(input_fn=input_test)
+
+ # The evaluation returns a Python dictionary. The "average_loss" key holds the
+ # Mean Squared Error (MSE).
+ average_loss = eval_result["average_loss"]
+
+ # Convert MSE to Root Mean Square Error (RMSE).
+ print("\n" + 80 * "*")
+ print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
+
+ print()
+
+
+if __name__ == "__main__":
+ # The Estimator periodically generates "INFO" logs; make these logs visible.
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main=main)
diff --git a/tensorflow/examples/get_started/regression/test.py b/tensorflow/examples/get_started/regression/test.py
new file mode 100644
index 0000000000..5a644cb8d6
--- /dev/null
+++ b/tensorflow/examples/get_started/regression/test.py
@@ -0,0 +1,73 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple smoke test that runs these examples for 1 training iteration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import pandas as pd
+
+from six.moves import StringIO
+
+import tensorflow.examples.get_started.regression.imports85 as imports85
+
+import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression # pylint: disable=g-bad-import-order,g-import-not-at-top
+import tensorflow.examples.get_started.regression.linear_regression as linear_regression
+import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+def four_lines():
+ # pylint: disable=line-too-long
+ text = StringIO("""
+ 1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500
+ 2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950
+ 2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450
+ 2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250""")
+ # pylint: enable=line-too-long
+
+ return pd.read_csv(text, names=imports85.header.keys(),
+ dtype=imports85.header, na_values='?')
+
+
+class RegressionTest(googletest.TestCase):
+  """Smoke-test each example's `main` with STEPS=1 on a four-row dataset."""
+
+  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
+  @test.mock.patch.dict(linear_regression.__dict__, {'STEPS': 1})
+  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  def test_linear_regression(self):
+    linear_regression.main([''])  # main() asserts len(argv) == 1.
+
+  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
+  @test.mock.patch.dict(linear_regression_categorical.__dict__, {'STEPS': 1})
+  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  def test_linear_regression_categorical(self):
+    linear_regression_categorical.main([''])  # main() asserts len(argv) == 1.
+
+  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
+  @test.mock.patch.dict(dnn_regression.__dict__, {'STEPS': 1})
+  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  def test_dnn_regression(self):
+    dnn_regression.main([''])  # main() asserts len(argv) == 1.
+
+
+if __name__ == '__main__':
+ googletest.main()