diff options
author | 2017-10-20 10:45:51 -0700 | |
---|---|---|
committer | 2017-10-20 10:57:03 -0700 | |
commit | 0f5683d629c6607d1baeaa44ecd264321ae05abc (patch) | |
tree | 197f9a31f858df7aad4006b47a213041a5a84454 /tensorflow/examples/learn | |
parent | 8f7439888c7c3ea7f188df64952cfb4f1e082ecc (diff) |
Migrate the iris example to use TF core API.
PiperOrigin-RevId: 172902682
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/iris.py | 101 |
1 files changed, 74 insertions, 27 deletions
diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 33e8d45801..0a50b3ba87 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -17,47 +17,94 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from sklearn import datasets -from sklearn import metrics -from sklearn import model_selection +import os +import urllib import tensorflow as tf +# Data sets +IRIS_TRAINING = 'iris_training.csv' +IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv' -X_FEATURE = 'x' # Name of the input feature. +IRIS_TEST = 'iris_test.csv' +IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv' + +FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] + + +def maybe_download_iris_data(file_name, download_url): + """Downloads the file and returns the number of data.""" + if not os.path.exists(file_name): + raw = urllib.urlopen(download_url).read() + with open(file_name, 'w') as f: + f.write(raw) + + # The first line is a comma-separated string. The first one is the number of + # total data in the file. + with open(file_name, 'r') as f: + first_line = f.readline() + num_elements = first_line.split(',')[0] + return int(num_elements) + + +def input_fn(file_name, num_data, batch_size, is_training): + """Creates an input_fn required by Estimator train/evaluate.""" + # If the data sets aren't stored locally, download them. + + def _parse_csv(rows_string_tensor): + """Takes the string input tensor and returns tuple of (features, labels).""" + # Last dim is the label. + num_features = len(FEATURE_KEYS) + num_columns = num_features + 1 + columns = tf.decode_csv(rows_string_tensor, + record_defaults=[[]] * num_columns) + features = dict(zip(FEATURE_KEYS, columns[:num_features])) + labels = tf.cast(columns[num_features], tf.int32) + return features, labels + + def _input_fn(): + """The input_fn.""" + dataset = tf.data.TextLineDataset([file_name]) + # Skip the first line (which does not have data). + dataset = dataset.skip(1) + dataset = dataset.map(_parse_csv) + + if is_training: + # For this small dataset, which can fit into memory, to achieve true + # randomness, the shuffle buffer size is set as the total number of + # elements in the dataset. + dataset = dataset.shuffle(num_data) + dataset = dataset.repeat() + + dataset = dataset.batch(batch_size) + iterator = dataset.make_one_shot_iterator() + features, labels = iterator.get_next() + return features, labels + + return _input_fn def main(unused_argv): - # Load dataset. - iris = datasets.load_iris() - x_train, x_test, y_train, y_test = model_selection.train_test_split( - iris.data, iris.target, test_size=0.2, random_state=42) + tf.logging.set_verbosity(tf.logging.INFO) + + num_training_data = maybe_download_iris_data( + IRIS_TRAINING, IRIS_TRAINING_URL) + num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL) # Build 3 layer DNN with 10, 20, 10 units respectively. feature_columns = [ - tf.feature_column.numeric_column( - X_FEATURE, shape=np.array(x_train).shape[1:])] + tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS] classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3) # Train. - train_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True) - classifier.train(input_fn=train_input_fn, steps=200) - - # Predict. - test_input_fn = tf.estimator.inputs.numpy_input_fn( - x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class_ids'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. + train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32, + is_training=True) + classifier.train(input_fn=train_input_fn, steps=400) + + # Eval. + test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32, + is_training=False) scores = classifier.evaluate(input_fn=test_input_fn) print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) |