author     Mark Daoust <markdaoust@google.com>             2017-09-06 12:14:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-09-06 12:24:20 -0700
commit     acc7c00588635765b96d6e1a74ff81b8b76ad45d (patch)
tree       a5efb59fd3d5ea018b213280fb813ff9d5ba8e49 /tensorflow/examples/get_started
parent     0f6a17c51e0dd67752d3196c99d4f4dc1746c55a (diff)
Add csv dataset example to get_started/regression.
PiperOrigin-RevId: 167754634
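For context, this change swaps the pandas-based `load_data()` / `tf.estimator.inputs.pandas_input_fn` pipeline for a new `imports85.dataset()` helper built on `tf.contrib.data`, which reads the CSV file directly. A minimal sketch of how the new input functions feed an Estimator, assuming the TF 1.3-era `contrib.data` API used in the diff below (the feature column and model here are illustrative, not part of this commit):

    import tensorflow as tf

    import imports85  # the module modified in this change

    (train, test) = imports85.dataset()

    def input_train():
      # Shuffle, batch, and repeat the training Dataset; each call hands the
      # Estimator a (features_dict, label) pair of tensors.
      return (train.shuffle(1000).batch(128)
              .repeat().make_one_shot_iterator().get_next())

    # Illustrative model; the real examples build richer feature columns.
    model = tf.estimator.LinearRegressor(
        feature_columns=[tf.feature_column.numeric_column("curb-weight")])
    model.train(input_fn=input_train, steps=1000)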
Diffstat (limited to 'tensorflow/examples/get_started')
-rw-r--r--  tensorflow/examples/get_started/regression/dnn_regression.py                |  18
-rw-r--r--  tensorflow/examples/get_started/regression/imports85.py                      | 170
-rw-r--r--  tensorflow/examples/get_started/regression/linear_regression.py              |  21
-rw-r--r--  tensorflow/examples/get_started/regression/linear_regression_categorical.py  |  21
-rw-r--r--  tensorflow/examples/get_started/regression/test.py                           |  66
5 files changed, 209 insertions(+), 87 deletions(-)
diff --git a/tensorflow/examples/get_started/regression/dnn_regression.py b/tensorflow/examples/get_started/regression/dnn_regression.py
index 06f0665e56..7aa3659139 100644
--- a/tensorflow/examples/get_started/regression/dnn_regression.py
+++ b/tensorflow/examples/get_started/regression/dnn_regression.py
@@ -28,15 +28,21 @@ STEPS = 5000
def main(argv):
"""Builds, trains, and evaluates the model."""
assert len(argv) == 1
- (x_train, y_train), (x_test, y_test) = imports85.load_data()
+ (train, test) = imports85.dataset()
# Build the training input_fn.
- input_train = tf.estimator.inputs.pandas_input_fn(
- x=x_train, y=y_train, num_epochs=None, shuffle=True)
+ def input_train():
+ return (
+ # Shuffling with a buffer larger than the data set ensures
+ # that the examples are well mixed.
+ train.shuffle(1000).batch(128)
+ # Repeat forever
+ .repeat().make_one_shot_iterator().get_next())
# Build the validation input_fn.
- input_test = tf.estimator.inputs.pandas_input_fn(
- x=x_test, y=y_test, shuffle=True)
+ def input_test():
+ return (test.shuffle(1000).batch(128)
+ .make_one_shot_iterator().get_next())
# The first way assigns a unique weight to each category. To do this you must
# specify the category's vocabulary (values outside this specification will
@@ -71,7 +77,7 @@ def main(argv):
# Train the model.
model.train(input_fn=input_train, steps=STEPS)
- # Evaluate how the model performs on data it has not yet seen.
+ # Evaluate how the model performs on data it has not yet seen.
eval_result = model.evaluate(input_fn=input_test)
# The evaluation returns a Python dictionary. The "average_loss" key holds the
diff --git a/tensorflow/examples/get_started/regression/imports85.py b/tensorflow/examples/get_started/regression/imports85.py
index 4532064622..41e77222ce 100644
--- a/tensorflow/examples/get_started/regression/imports85.py
+++ b/tensorflow/examples/get_started/regression/imports85.py
@@ -21,53 +21,149 @@ from __future__ import print_function
import collections
import numpy as np
-import pandas as pd
import tensorflow as tf
-header = collections.OrderedDict([
- ("symboling", np.int32),
- ("normalized-losses", np.float32),
- ("make", str),
- ("fuel-type", str),
- ("aspiration", str),
- ("num-of-doors", str),
- ("body-style", str),
- ("drive-wheels", str),
- ("engine-location", str),
- ("wheel-base", np.float32),
- ("length", np.float32),
- ("width", np.float32),
- ("height", np.float32),
- ("curb-weight", np.float32),
- ("engine-type", str),
- ("num-of-cylinders", str),
- ("engine-size", np.float32),
- ("fuel-system", str),
- ("bore", np.float32),
- ("stroke", np.float32),
- ("compression-ratio", np.float32),
- ("horsepower", np.float32),
- ("peak-rpm", np.float32),
- ("city-mpg", np.float32),
- ("highway-mpg", np.float32),
- ("price", np.float32)
+try:
+ import pandas as pd # pylint: disable=g-import-not-at-top
+except ImportError:
+ pass
+
+
+URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data"
+
+# Order is important for the csv-readers, so we use an OrderedDict here.
+defaults = collections.OrderedDict([
+ ("symboling", [0]),
+ ("normalized-losses", [0.0]),
+ ("make", [""]),
+ ("fuel-type", [""]),
+ ("aspiration", [""]),
+ ("num-of-doors", [""]),
+ ("body-style", [""]),
+ ("drive-wheels", [""]),
+ ("engine-location", [""]),
+ ("wheel-base", [0.0]),
+ ("length", [0.0]),
+ ("width", [0.0]),
+ ("height", [0.0]),
+ ("curb-weight", [0.0]),
+ ("engine-type", [""]),
+ ("num-of-cylinders", [""]),
+ ("engine-size", [0.0]),
+ ("fuel-system", [""]),
+ ("bore", [0.0]),
+ ("stroke", [0.0]),
+ ("compression-ratio", [0.0]),
+ ("horsepower", [0.0]),
+ ("peak-rpm", [0.0]),
+ ("city-mpg", [0.0]),
+ ("highway-mpg", [0.0]),
+ ("price", [0.0])
]) # pyformat: disable
-def raw():
- """Get the imports85 data and load it as a pd.DataFrame."""
- url = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data" # pylint: disable=line-too-long
- # Download and cache the data.
- path = tf.contrib.keras.utils.get_file(url.split("/")[-1], url)
+types = collections.OrderedDict((key, type(value[0]))
+ for key, value in defaults.items())
- # Load the CSV data into a pandas dataframe.
- df = pd.read_csv(path, names=header.keys(), dtype=header, na_values="?")
+
+def _get_imports85():
+ path = tf.contrib.keras.utils.get_file(URL.split("/")[-1], URL)
+ return path
+
+
+def dataset(y_name="price", train_fraction=0.7):
+ """Load the imports85 data as a (train,test) pair of `Dataset`.
+
+ Each dataset generates (features_dict, label) pairs.
+
+ Args:
+ y_name: The name of the column to use as the label.
+ train_fraction: A float, the fraction of data to use for training. The
+ remainder will be used for evaluation.
+ Returns:
+ A (train,test) pair of `Datasets`
+ """
+ # Download and cache the data
+ path = _get_imports85()
+
+ # Define how the lines of the file should be parsed
+ def decode_line(line):
+ """Convert a csv line into a (features_dict,label) pair."""
+ # Decode the line to a tuple of items based on the types of
+ # defaults.values().
+ items = tf.decode_csv(line, defaults.values())
+
+ # Convert the keys and items to a dict.
+ pairs = zip(defaults.keys(), items)
+ features_dict = dict(pairs)
+
+ # Remove the label from the features_dict
+ label = features_dict.pop(y_name)
+
+ return features_dict, label
+
+ def has_no_question_marks(line):
+ """Returns True if the line of text has no question marks."""
+ # split the line into an array of characters
+ chars = tf.string_split(line[tf.newaxis], "").values
+ # for each character check if it is a question mark
+ is_question = tf.equal(chars, "?")
+ any_question = tf.reduce_any(is_question)
+ no_question = ~any_question
+
+ return no_question
+
+ def in_training_set(line):
+ """Returns a boolean tensor, true if the line is in the training set."""
+ # If you randomly split the dataset you won't get the same split in both
+ # sessions if you stop and restart training later. Also a simple
+ # random split won't work with a dataset that's too big to `.cache()` as
+ # we are doing here.
+ num_buckets = 1000000
+ bucket_id = tf.string_to_hash_bucket_fast(line, num_buckets)
+ # Use the hash bucket id as a random number that's deterministic per example
+ return bucket_id < int(train_fraction * num_buckets)
+
+ def in_test_set(line):
+ """Returns a boolean tensor, true if the line is in the training set."""
+ # Items not in the training set are in the test set.
+ # This line must use `~` instead of `not` because `not` only works on Python
+ # booleans but we are dealing with symbolic tensors.
+ return ~in_training_set(line)
+
+ base_dataset = (tf.contrib.data
+ # Get the lines from the file.
+ .TextLineDataset(path)
+ # drop lines with question marks.
+ .filter(has_no_question_marks))
+
+ train = (base_dataset
+ # Take only the training-set lines.
+ .filter(in_training_set)
+ # Cache data so you only read the file once.
+ .cache()
+ # Decode each line into a (features_dict, label) pair.
+ .map(decode_line))
+
+ # Do the same for the test-set.
+ test = (base_dataset.filter(in_test_set).cache().map(decode_line))
+
+ return train, test
+
+
+def raw_dataframe():
+ """Load the imports85 data as a pd.DataFrame."""
+ # Download and cache the data
+ path = _get_imports85()
+
+ # Load it into a pandas dataframe
+ df = pd.read_csv(path, names=types.keys(), dtype=types, na_values="?")
return df
def load_data(y_name="price", train_fraction=0.7, seed=None):
- """Returns the imports85 shuffled and split into train and test subsets.
+ """Get the imports85 data set.
A description of the data is available at:
https://archive.ics.uci.edu/ml/datasets/automobile
@@ -88,7 +184,7 @@ def load_data(y_name="price", train_fraction=0.7, seed=None):
array.
"""
# Load the raw data columns.
- data = raw()
+ data = raw_dataframe()
# Delete rows with unknowns
data = data.dropna()
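A note on the split introduced above: `in_training_set` hashes each raw CSV line into one of a million buckets, so every example lands on the same side of the train/test split on every run, with no shuffling or stored seed. A standalone sketch of the same idea, assuming the TF 1.x `tf.string_to_hash_bucket_fast` op used in the diff (the sample lines are made up for illustration):

    import tensorflow as tf

    # Two illustrative CSV lines (not real rows from the imports85 data).
    lines = tf.constant(["2,164,audi,gas,std,four,sedan",
                         "1,158,audi,gas,turbo,two,hatchback"])

    num_buckets = 1000000
    train_fraction = 0.7

    # Hash each whole line; the bucket depends only on the line's contents,
    # so the train/test assignment is stable across runs and restarts.
    bucket_id = tf.string_to_hash_bucket_fast(lines, num_buckets)
    in_train = bucket_id < int(train_fraction * num_buckets)

    with tf.Session() as sess:
      print(sess.run(in_train))  # deterministic per line, e.g. [ True False]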
diff --git a/tensorflow/examples/get_started/regression/linear_regression.py b/tensorflow/examples/get_started/regression/linear_regression.py
index 9793163323..dd44077663 100644
--- a/tensorflow/examples/get_started/regression/linear_regression.py
+++ b/tensorflow/examples/get_started/regression/linear_regression.py
@@ -29,20 +29,21 @@ STEPS = 1000
def main(argv):
"""Builds, trains, and evaluates the model."""
assert len(argv) == 1
- (x_train, y_train), (x_test, y_test) = imports85.load_data()
+ (train, test) = imports85.dataset()
# Build the training input_fn.
- input_train = tf.estimator.inputs.pandas_input_fn(
- x=x_train,
- y=y_train,
- # Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
- # indefinitely, leaving the call to `Estimator.train` in control.
- num_epochs=None,
- shuffle=True)
+ def input_train():
+ return (
+ # Shuffling with a buffer larger than the data set ensures
+ # that the examples are well mixed.
+ train.shuffle(1000).batch(128)
+ # Repeat forever
+ .repeat().make_one_shot_iterator().get_next())
# Build the validation input_fn.
- input_test = tf.estimator.inputs.pandas_input_fn(
- x=x_test, y=y_test, shuffle=True)
+ def input_test():
+ return (test.shuffle(1000).batch(128)
+ .make_one_shot_iterator().get_next())
feature_columns = [
# "curb-weight" and "highway-mpg" are numeric columns.
diff --git a/tensorflow/examples/get_started/regression/linear_regression_categorical.py b/tensorflow/examples/get_started/regression/linear_regression_categorical.py
index 0a416595e6..38ecfada9d 100644
--- a/tensorflow/examples/get_started/regression/linear_regression_categorical.py
+++ b/tensorflow/examples/get_started/regression/linear_regression_categorical.py
@@ -28,20 +28,21 @@ STEPS = 1000
def main(argv):
"""Builds, trains, and evaluates the model."""
assert len(argv) == 1
- (x_train, y_train), (x_test, y_test) = imports85.load_data()
+ (train, test) = imports85.dataset()
# Build the training input_fn.
- input_train = tf.estimator.inputs.pandas_input_fn(
- x=x_train,
- y=y_train,
- # Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
- # indefinitely, leaving the call to `Estimator.train` in control.
- num_epochs=None,
- shuffle=True)
+ def input_train():
+ return (
+ # Shuffling with a buffer larger than the data set ensures
+ # that the examples are well mixed.
+ train.shuffle(1000).batch(128)
+ # Repeat forever
+ .repeat().make_one_shot_iterator().get_next())
# Build the validation input_fn.
- input_test = tf.estimator.inputs.pandas_input_fn(
- x=x_test, y=y_test, shuffle=True)
+ def input_test():
+ return (test.shuffle(1000).batch(128)
+ .make_one_shot_iterator().get_next())
# The following code demonstrates two of the ways that `feature_columns` can
# be used to build a model with categorical inputs.
diff --git a/tensorflow/examples/get_started/regression/test.py b/tensorflow/examples/get_started/regression/test.py
index 5a644cb8d6..fa06dde9ae 100644
--- a/tensorflow/examples/get_started/regression/test.py
+++ b/tensorflow/examples/get_started/regression/test.py
@@ -26,48 +26,66 @@ from six.moves import StringIO
import tensorflow.examples.get_started.regression.imports85 as imports85
-import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression # pylint: disable=g-bad-import-order,g-import-not-at-top
+sys.modules["imports85"] = imports85
+
+# pylint: disable=g-bad-import-order,g-import-not-at-top
+import tensorflow.contrib.data as data
+
+import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression
import tensorflow.examples.get_started.regression.linear_regression as linear_regression
import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
+# pylint: disable=g-bad-import-order,g-import-not-at-top
+
+
+# pylint: disable=line-too-long
+FOUR_LINES = "\n".join([
+ "1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500",
+ "2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950",
+ "2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450",
+ "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250",])
+
+# pylint: enable=line-too-long
+
+
+def four_lines_dataframe():
+ text = StringIO(FOUR_LINES)
+ return pd.read_csv(text, names=imports85.types.keys(),
+ dtype=imports85.types, na_values="?")
-def four_lines():
- # pylint: disable=line-too-long
- text = StringIO("""
- 1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500
- 2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950
- 2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450
- 2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250""")
- # pylint: enable=line-too-long
- return pd.read_csv(text, names=imports85.header.keys(),
- dtype=imports85.header, na_values='?')
+def four_lines_dataset(*args, **kwargs):
+ del args, kwargs
+ return data.Dataset.from_tensor_slices(FOUR_LINES.split("\n"))
class RegressionTest(googletest.TestCase):
"""Test the regression examples in this directory."""
- @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
- @test.mock.patch.dict(linear_regression.__dict__, {'STEPS': 1})
- @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+ @test.mock.patch.dict(data.__dict__,
+ {"TextLineDataset": four_lines_dataset})
+ @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+ @test.mock.patch.dict(linear_regression.__dict__, {"STEPS": 1})
def test_linear_regression(self):
- linear_regression.main([])
+ linear_regression.main([""])
- @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
- @test.mock.patch.dict(linear_regression_categorical.__dict__, {'STEPS': 1})
- @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+ @test.mock.patch.dict(data.__dict__,
+ {"TextLineDataset": four_lines_dataset})
+ @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+ @test.mock.patch.dict(linear_regression_categorical.__dict__, {"STEPS": 1})
def test_linear_regression_categorical(self):
- linear_regression_categorical.main([])
+ linear_regression_categorical.main([""])
- @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
- @test.mock.patch.dict(dnn_regression.__dict__, {'STEPS': 1})
- @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+ @test.mock.patch.dict(data.__dict__,
+ {"TextLineDataset": four_lines_dataset})
+ @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+ @test.mock.patch.dict(dnn_regression.__dict__, {"STEPS": 1})
def test_dnn_regression(self):
- dnn_regression.main([])
+ dnn_regression.main([""])
-if __name__ == '__main__':
+if __name__ == "__main__":
googletest.main()
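The tests above avoid any file I/O by patching module dictionaries: `mock.patch.dict` temporarily replaces `TextLineDataset` in `tf.contrib.data` with a function that returns an in-memory Dataset, and replaces `_get_imports85` with a no-op so nothing is downloaded. A minimal sketch of that pattern, assuming the same TF 1.3-era APIs (the stub names here are illustrative, not from this commit):

    import tensorflow.contrib.data as data
    from tensorflow.python.platform import test

    def fake_text_line_dataset(*args, **kwargs):
      # Ignore the path argument and return a tiny in-memory "file".
      del args, kwargs
      return data.Dataset.from_tensor_slices(["a,b,c", "d,e,f"])

    @test.mock.patch.dict(data.__dict__,
                          {"TextLineDataset": fake_text_line_dataset})
    def read_with_stub():
      # Inside this call, data.TextLineDataset(...) yields the stub's lines.
      return data.TextLineDataset("ignored.csv")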