aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-10-20 10:45:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-20 10:57:03 -0700
commit0f5683d629c6607d1baeaa44ecd264321ae05abc (patch)
tree197f9a31f858df7aad4006b47a213041a5a84454 /tensorflow/examples/learn
parent8f7439888c7c3ea7f188df64952cfb4f1e082ecc (diff)
Migrate the iris example to use TF core API.
PiperOrigin-RevId: 172902682
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/iris.py101
1 files changed, 74 insertions, 27 deletions
diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py
index 33e8d45801..0a50b3ba87 100644
--- a/tensorflow/examples/learn/iris.py
+++ b/tensorflow/examples/learn/iris.py
@@ -17,47 +17,94 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from sklearn import datasets
-from sklearn import metrics
-from sklearn import model_selection
+import os
+import urllib
import tensorflow as tf
+# Data sets
+IRIS_TRAINING = 'iris_training.csv'
+IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'
-X_FEATURE = 'x' # Name of the input feature.
+IRIS_TEST = 'iris_test.csv'
+IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
+
+FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
+
+
+def maybe_download_iris_data(file_name, download_url):
+  """Downloads the file if it is missing and returns its number of records.
+
+  Args:
+    file_name: Local path of the CSV file. It is created from `download_url`
+      when it does not exist yet.
+    download_url: URL the CSV file is fetched from when `file_name` is absent.
+
+  Returns:
+    The integer stored in the first comma-separated field of the file's first
+    line.
+  """
+  if not os.path.exists(file_name):
+    # Imported locally so the lookup works on both Python 3
+    # (`urllib.request.urlopen`) and Python 2 (`urllib.urlopen`).
+    import contextlib
+    try:
+      from urllib.request import urlopen
+    except ImportError:
+      from urllib import urlopen
+
+    # `closing` releases the connection on both Python versions.
+    with contextlib.closing(urlopen(download_url)) as response:
+      raw = response.read()
+    # The payload is bytes, so it must be written in binary mode; a text-mode
+    # handle raises TypeError on Python 3.
+    with open(file_name, 'wb') as f:
+      f.write(raw)
+
+  # The first line is a comma-separated header whose first field is the total
+  # number of data rows in the file.
+  with open(file_name, 'r') as f:
+    first_line = f.readline()
+  num_elements = first_line.split(',')[0]
+  return int(num_elements)
+
+
+def input_fn(file_name, num_data, batch_size, is_training):
+ """Creates an input_fn required by Estimator train/evaluate.
+
+ Args:
+ file_name: Path of the CSV file to read; its first line is skipped.
+ num_data: Shuffle buffer size, used only when `is_training` is true.
+ batch_size: Number of rows grouped into each batch.
+ is_training: If true, the dataset is shuffled and repeated.
+
+ Returns:
+ A no-argument function that returns a `(features, labels)` pair.
+ """
+ # Nothing is downloaded here; `file_name` is handed to the dataset as-is.
+
+ def _parse_csv(rows_string_tensor):
+ """Takes the string input tensor and returns tuple of (features, labels)."""
+ # The last column is the label; the columns before it are the features, in
+ # the order given by FEATURE_KEYS.
+ num_features = len(FEATURE_KEYS)
+ num_columns = num_features + 1
+ # NOTE(review): an empty default per column presumably marks every column
+ # as required -- confirm against the tf.decode_csv documentation.
+ columns = tf.decode_csv(rows_string_tensor,
+ record_defaults=[[]] * num_columns)
+ features = dict(zip(FEATURE_KEYS, columns[:num_features]))
+ labels = tf.cast(columns[num_features], tf.int32)
+ return features, labels
+
+ def _input_fn():
+ """Builds the dataset pipeline and returns one (features, labels) batch."""
+ dataset = tf.data.TextLineDataset([file_name])
+ # Skip the first line (which does not have data).
+ dataset = dataset.skip(1)
+ dataset = dataset.map(_parse_csv)
+
+ if is_training:
+ # For this small dataset, which can fit into memory, to achieve true
+ # randomness, the shuffle buffer size is set as the total number of
+ # elements in the dataset.
+ dataset = dataset.shuffle(num_data)
+ dataset = dataset.repeat()
+
+ dataset = dataset.batch(batch_size)
+ iterator = dataset.make_one_shot_iterator()
+ features, labels = iterator.get_next()
+ return features, labels
+
+ return _input_fn
def main(unused_argv):
+ """Trains a DNNClassifier on the iris CSV data and prints its test accuracy."""
- # Load dataset.
- iris = datasets.load_iris()
- x_train, x_test, y_train, y_test = model_selection.train_test_split(
- iris.data, iris.target, test_size=0.2, random_state=42)
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ # Fetch each CSV file if it is not already present, and keep the record
+ # count read from its first line.
+ num_training_data = maybe_download_iris_data(
+ IRIS_TRAINING, IRIS_TRAINING_URL)
+ num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)
# Build 3 layer DNN with 10, 20, 10 units respectively.
feature_columns = [
- tf.feature_column.numeric_column(
- X_FEATURE, shape=np.array(x_train).shape[1:])]
+ tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
# Train.
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
- classifier.train(input_fn=train_input_fn, steps=200)
-
- # Predict.
- test_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
- predictions = classifier.predict(input_fn=test_input_fn)
- y_predicted = np.array(list(p['class_ids'] for p in predictions))
- y_predicted = y_predicted.reshape(np.array(y_test).shape)
-
- # Score with sklearn.
- score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy (sklearn): {0:f}'.format(score))
-
- # Score with tensorflow.
+ train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32,
+ is_training=True)
+ classifier.train(input_fn=train_input_fn, steps=400)
+
+ # Eval.
+ # With is_training=False the input_fn neither shuffles nor repeats the data.
+ test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32,
+ is_training=False)
scores = classifier.evaluate(input_fn=test_input_fn)
print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))