aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-27 08:35:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-27 08:40:47 -0700
commita132b8330039e7ed326d090cdae35c97561f68b1 (patch)
tree0a24d9c96e22863a69a642fc1801da6af47b24d1 /tensorflow/examples/learn
parent6e046535dbcc1249a001256cb6e89355288ac750 (diff)
Updates some more examples in examples/learn.
PiperOrigin-RevId: 160278757
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/iris.py10
-rw-r--r--tensorflow/examples/learn/iris_custom_decay_dnn.py89
-rw-r--r--tensorflow/examples/learn/iris_custom_model.py101
-rw-r--r--tensorflow/examples/learn/iris_run_config.py52
4 files changed, 173 insertions, 79 deletions
diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py
index 2ec490b7a2..33e8d45801 100644
--- a/tensorflow/examples/learn/iris.py
+++ b/tensorflow/examples/learn/iris.py
@@ -25,6 +25,9 @@ from sklearn import model_selection
import tensorflow as tf
+X_FEATURE = 'x' # Name of the input feature.
+
+
def main(unused_argv):
# Load dataset.
iris = datasets.load_iris()
@@ -33,18 +36,19 @@ def main(unused_argv):
# Build 3 layer DNN with 10, 20, 10 units respectively.
feature_columns = [
- tf.feature_column.numeric_column('x', shape=np.array(x_train).shape[1:])]
+ tf.feature_column.numeric_column(
+ X_FEATURE, shape=np.array(x_train).shape[1:])]
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
# Train.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={'x': x_train}, y=y_train, num_epochs=None, shuffle=True)
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=200)
# Predict.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={'x': x_test}, y=y_test, num_epochs=1, shuffle=False)
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
predictions = classifier.predict(input_fn=test_input_fn)
y_predicted = np.array(list(p['class_ids'] for p in predictions))
y_predicted = y_predicted.reshape(np.array(y_test).shape)
diff --git a/tensorflow/examples/learn/iris_custom_decay_dnn.py b/tensorflow/examples/learn/iris_custom_decay_dnn.py
index 31acbd30cd..072357e51c 100644
--- a/tensorflow/examples/learn/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/learn/iris_custom_decay_dnn.py
@@ -17,36 +17,87 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
from sklearn import datasets
from sklearn import metrics
-from sklearn.cross_validation import train_test_split
+from sklearn import model_selection
import tensorflow as tf
-def optimizer_exp_decay():
- global_step = tf.contrib.framework.get_or_create_global_step()
- learning_rate = tf.train.exponential_decay(
- learning_rate=0.1, global_step=global_step,
- decay_steps=100, decay_rate=0.001)
- return tf.train.AdagradOptimizer(learning_rate=learning_rate)
+X_FEATURE = 'x' # Name of the input feature.
+
+
+def my_model(features, labels, mode):
+ """DNN with three hidden layers."""
+ # Create three fully connected layers respectively of size 10, 20, and 10.
+ net = features[X_FEATURE]
+ for units in [10, 20, 10]:
+ net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
+
+ # Compute logits (1 per class).
+ logits = tf.layers.dense(net, 3, activation=None)
+
+ # Compute predictions.
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+
+ # Convert the labels to a one-hot tensor of shape (length of features, 3) and
+ # with a on-value of 1 for each one-hot vector of length 3.
+ onehot_labels = tf.one_hot(labels, 3, 1, 0)
+ # Compute loss.
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ # Create training op with exponentially decaying learning rate.
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ global_step = tf.train.get_global_step()
+ learning_rate = tf.train.exponential_decay(
+ learning_rate=0.1, global_step=global_step,
+ decay_steps=100, decay_rate=0.001)
+ optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
+ train_op = optimizer.minimize(loss, global_step=global_step)
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ # Compute evaluation metrics.
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=eval_metric_ops)
def main(unused_argv):
iris = datasets.load_iris()
- x_train, x_test, y_train, y_test = train_test_split(
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
- feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
- x_train)
- classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
- hidden_units=[10, 20, 10],
- n_classes=3,
- optimizer=optimizer_exp_decay)
-
- classifier.fit(x_train, y_train, steps=800)
- predictions = list(classifier.predict(x_test, as_iterable=True))
- score = metrics.accuracy_score(y_test, predictions)
- print('Accuracy: {0:f}'.format(score))
+ classifier = tf.estimator.Estimator(model_fn=my_model)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=1000)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
+ score = metrics.accuracy_score(y_test, y_predicted)
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
diff --git a/tensorflow/examples/learn/iris_custom_model.py b/tensorflow/examples/learn/iris_custom_model.py
index fbc50716c9..471a99ba76 100644
--- a/tensorflow/examples/learn/iris_custom_model.py
+++ b/tensorflow/examples/learn/iris_custom_model.py
@@ -16,62 +16,85 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import cross_validation
+import numpy as np
from sklearn import datasets
from sklearn import metrics
+from sklearn import model_selection
import tensorflow as tf
-layers = tf.contrib.layers
-learn = tf.contrib.learn
+X_FEATURE = 'x' # Name of the input feature.
-def my_model(features, target):
- """DNN with three hidden layers, and dropout of 0.1 probability."""
- # Convert the target to a one-hot tensor of shape (length of features, 3) and
- # with a on-value of 1 for each one-hot vector of length 3.
- target = tf.one_hot(target, 3, 1, 0)
+def my_model(features, labels, mode):
+ """DNN with three hidden layers, and dropout of 0.1 probability."""
# Create three fully connected layers respectively of size 10, 20, and 10 with
# each layer having a dropout probability of 0.1.
- normalizer_fn = layers.dropout
- normalizer_params = {'keep_prob': 0.9}
- features = layers.stack(
- features,
- layers.fully_connected, [10, 20, 10],
- normalizer_fn=normalizer_fn,
- normalizer_params=normalizer_params)
-
- # Compute logits (1 per class) and compute loss.
- logits = layers.fully_connected(features, 3, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
-
- # Create a tensor for training op.
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adagrad',
- learning_rate=0.1)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ net = features[X_FEATURE]
+ for units in [10, 20, 10]:
+ net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
+ net = tf.layers.dropout(net, rate=0.1)
+
+ # Compute logits (1 per class).
+ logits = tf.layers.dense(net, 3, activation=None)
+
+ # Compute predictions.
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+
+ # Convert the labels to a one-hot tensor of shape (length of features, 3) and
+ # with a on-value of 1 for each one-hot vector of length 3.
+ onehot_labels = tf.one_hot(labels, 3, 1, 0)
+ # Compute loss.
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ # Create training op.
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ # Compute evaluation metrics.
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=eval_metric_ops)
def main(unused_argv):
iris = datasets.load_iris()
- x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
- classifier = learn.Estimator(model_fn=my_model)
- classifier.fit(x_train, y_train, steps=1000)
+ classifier = tf.estimator.Estimator(model_fn=my_model)
- y_predicted = [
- p['class'] for p in classifier.predict(
- x_test, as_iterable=True)
- ]
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=1000)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
diff --git a/tensorflow/examples/learn/iris_run_config.py b/tensorflow/examples/learn/iris_run_config.py
index b7b8b5cd01..286c824e30 100644
--- a/tensorflow/examples/learn/iris_run_config.py
+++ b/tensorflow/examples/learn/iris_run_config.py
@@ -18,37 +18,53 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import cross_validation
+import numpy as np
from sklearn import datasets
from sklearn import metrics
+from sklearn import model_selection
import tensorflow as tf
+X_FEATURE = 'x' # Name of the input feature.
+
+
def main(unused_argv):
# Load dataset.
iris = datasets.load_iris()
- x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
# You can define you configurations by providing a RunConfig object to
- # estimator to control session configurations, e.g. num_cores
- # and gpu_memory_fraction
- run_config = tf.contrib.learn.estimators.RunConfig(
- num_cores=3, gpu_memory_fraction=0.6)
+ # estimator to control session configurations, e.g. tf_random_seed.
+ run_config = tf.estimator.RunConfig().replace(tf_random_seed=1)
# Build 3 layer DNN with 10, 20, 10 units respectively.
- feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(
- x_train)
- classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
- hidden_units=[10, 20, 10],
- n_classes=3,
- config=run_config)
-
- # Fit and predict.
- classifier.fit(x_train, y_train, steps=200)
- predictions = list(classifier.predict(x_test, as_iterable=True))
- score = metrics.accuracy_score(y_test, predictions)
- print('Accuracy: {0:f}'.format(score))
+ feature_columns = [
+ tf.feature_column.numeric_column(
+ X_FEATURE, shape=np.array(x_train).shape[1:])]
+ classifier = tf.estimator.DNNClassifier(
+ feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3,
+ config=run_config)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=200)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class_ids'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
+ score = metrics.accuracy_score(y_test, y_predicted)
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':