path: root/tensorflow/examples/learn
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-29 10:04:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-06-29 10:09:02 -0700
commit d0360669091554fbafe532bc33768e572b0f01f8 (patch)
tree   6c0ea35757b17c4f4527548bbcf684e78de6d471 /tensorflow/examples/learn
parent 7eba207ed3046fbbe7e9b628fba30ae7498697d0 (diff)
Updates remaining examples in examples/learn.
PiperOrigin-RevId: 160538962
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--  tensorflow/examples/learn/hdf5_classification.py |  41
-rw-r--r--  tensorflow/examples/learn/mnist.py               | 131
-rw-r--r--  tensorflow/examples/learn/multiple_gpu.py        | 107
-rwxr-xr-x  tensorflow/examples/learn/resnet.py              | 197
4 files changed, 277 insertions, 199 deletions
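All four diffs below apply the same migration: the deprecated tf.contrib.learn fit/predict interface is replaced by tf.estimator.Estimator classes driven by input functions. A minimal sketch of the new pattern, using illustrative toy data and names that are not part of this commit:

import numpy as np
import tensorflow as tf

# Toy data; the feature name 'x' and the shapes are illustrative only.
x = np.random.rand(100, 4).astype(np.float32)
y = np.random.randint(0, 3, size=100)

feature_columns = [tf.feature_column.numeric_column('x', shape=[4])]
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns, hidden_units=[10], n_classes=3)

# Input functions replace the old classifier.fit(x, y) call style.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': x}, y=y, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=10)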
diff --git a/tensorflow/examples/learn/hdf5_classification.py b/tensorflow/examples/learn/hdf5_classification.py
index db37500246..3a46bbcf41 100644
--- a/tensorflow/examples/learn/hdf5_classification.py
+++ b/tensorflow/examples/learn/hdf5_classification.py
@@ -11,25 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Example of DNNClassifier for Iris plant dataset, h5 format."""
+"""Example of DNNClassifier for Iris plant dataset, hdf5 format."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-from sklearn import cross_validation
+from sklearn import datasets
from sklearn import metrics
+from sklearn import model_selection
import tensorflow as tf
import h5py # pylint: disable=g-bad-import-order
-learn = tf.contrib.learn
+
+X_FEATURE = 'x' # Name of the input feature.
def main(unused_argv):
# Load dataset.
- iris = learn.datasets.load_dataset('iris')
- x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ iris = datasets.load_iris()
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
# Note that we are saving and loading iris data in h5 format as a simple
@@ -48,14 +50,31 @@ def main(unused_argv):
y_test = np.array(h5f['y_test'])
# Build 3 layer DNN with 10, 20, 10 units respectively.
- feature_columns = learn.infer_real_valued_columns_from_input(x_train)
- classifier = learn.DNNClassifier(
+ feature_columns = [
+ tf.feature_column.numeric_column(
+ X_FEATURE, shape=np.array(x_train).shape[1:])]
+ classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
- # Fit and predict.
- classifier.fit(x_train, y_train, steps=200)
- score = metrics.accuracy_score(y_test, classifier.predict(x_test))
- print('Accuracy: {0:f}'.format(score))
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=200)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class_ids'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
+ score = metrics.accuracy_score(y_test, y_predicted)
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
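
For reference, a minimal sketch of the h5py round-trip that hdf5_classification.py relies on; the file name and dataset key here mirror the example but are illustrative:

import h5py
import numpy as np

# Write a NumPy array into an HDF5 file, then read it back.
with h5py.File('/tmp/test_hdf5.h5', 'w') as h5f:
  h5f.create_dataset('x_train', data=np.arange(12.0).reshape(4, 3))

with h5py.File('/tmp/test_hdf5.h5', 'r') as h5f:
  x_train = np.array(h5f['x_train'])  # Datasets read back as arrays.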
diff --git a/tensorflow/examples/learn/mnist.py b/tensorflow/examples/learn/mnist.py
index 15cf4b91dd..5344526b52 100644
--- a/tensorflow/examples/learn/mnist.py
+++ b/tensorflow/examples/learn/mnist.py
@@ -22,89 +22,110 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from sklearn import metrics
import tensorflow as tf
-layers = tf.contrib.layers
-learn = tf.contrib.learn
+N_DIGITS = 10 # Number of digits.
+X_FEATURE = 'x' # Name of the input feature.
-def max_pool_2x2(tensor_in):
- return tf.nn.max_pool(
- tensor_in, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
-
-def conv_model(feature, target, mode):
+def conv_model(features, labels, mode):
"""2-layer convolution model."""
- # Convert the target to a one-hot tensor of shape (batch_size, 10) and
- # with a on-value of 1 for each one-hot vector of length 10.
- target = tf.one_hot(tf.cast(target, tf.int32), 10, 1, 0)
-
# Reshape feature to 4d tensor with 2nd and 3rd dimensions being
# image width and height, final dimension being the number of color channels.
- feature = tf.reshape(feature, [-1, 28, 28, 1])
+ feature = tf.reshape(features[X_FEATURE], [-1, 28, 28, 1])
# First conv layer will compute 32 features for each 5x5 patch
with tf.variable_scope('conv_layer1'):
- h_conv1 = layers.convolution2d(
- feature, 32, kernel_size=[5, 5], activation_fn=tf.nn.relu)
- h_pool1 = max_pool_2x2(h_conv1)
+ h_conv1 = tf.layers.conv2d(
+ feature,
+ filters=32,
+ kernel_size=[5, 5],
+ padding='same',
+ activation=tf.nn.relu)
+ h_pool1 = tf.layers.max_pooling2d(
+ h_conv1, pool_size=2, strides=2, padding='same')
# Second conv layer will compute 64 features for each 5x5 patch.
with tf.variable_scope('conv_layer2'):
- h_conv2 = layers.convolution2d(
- h_pool1, 64, kernel_size=[5, 5], activation_fn=tf.nn.relu)
- h_pool2 = max_pool_2x2(h_conv2)
+ h_conv2 = tf.layers.conv2d(
+ h_pool1,
+ filters=64,
+ kernel_size=[5, 5],
+ padding='same',
+ activation=tf.nn.relu)
+ h_pool2 = tf.layers.max_pooling2d(
+ h_conv2, pool_size=2, strides=2, padding='same')
# reshape tensor into a batch of vectors
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
# Densely connected layer with 1024 neurons.
- h_fc1 = layers.dropout(
- layers.fully_connected(
- h_pool2_flat, 1024, activation_fn=tf.nn.relu),
- keep_prob=0.5,
- is_training=mode == tf.contrib.learn.ModeKeys.TRAIN)
+ h_fc1 = tf.layers.dense(h_pool2_flat, 1024, activation=tf.nn.relu)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ h_fc1 = tf.layers.dropout(h_fc1, rate=0.5, training=True)
# Compute logits (1 per class) and compute loss.
- logits = layers.fully_connected(h_fc1, 10, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
-
- # Create a tensor for training op.
- train_op = layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='SGD',
- learning_rate=0.001)
-
- return tf.argmax(logits, 1), loss, train_op
+ logits = tf.layers.dense(h_fc1, N_DIGITS, activation=None)
+
+ # Compute predictions.
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+
+ # Compute loss.
+ onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ # Create training op.
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ # Compute evaluation metrics.
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=eval_metric_ops)
def main(unused_args):
### Download and load MNIST dataset.
- mnist = learn.datasets.load_dataset('mnist')
+ mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: mnist.train.images},
+ y=mnist.train.labels.astype(np.int32),
+ batch_size=100,
+ num_epochs=None,
+ shuffle=True)
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: mnist.test.images},
+ y=mnist.test.labels.astype(np.int32),
+ num_epochs=1,
+ shuffle=False)
### Linear classifier.
- feature_columns = learn.infer_real_valued_columns_from_input(
- mnist.train.images)
- classifier = learn.LinearClassifier(
- feature_columns=feature_columns, n_classes=10)
- classifier.fit(mnist.train.images,
- mnist.train.labels.astype(np.int32),
- batch_size=100,
- steps=1000)
- score = metrics.accuracy_score(mnist.test.labels,
- list(classifier.predict(mnist.test.images)))
- print('Accuracy: {0:f}'.format(score))
+ feature_columns = [
+ tf.feature_column.numeric_column(
+ X_FEATURE, shape=mnist.train.images.shape[1:])]
+ classifier = tf.estimator.LinearClassifier(
+ feature_columns=feature_columns, n_classes=N_DIGITS)
+ classifier.train(input_fn=train_input_fn, steps=200)
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (LinearClassifier): {0:f}'.format(scores['accuracy']))
### Convolutional network
- classifier = learn.Estimator(model_fn=conv_model)
- classifier.fit(mnist.train.images,
- mnist.train.labels,
- batch_size=100,
- steps=20000)
- score = metrics.accuracy_score(mnist.test.labels,
- list(classifier.predict(mnist.test.images)))
- print('Accuracy: {0:f}'.format(score))
+ classifier = tf.estimator.Estimator(model_fn=conv_model)
+ classifier.train(input_fn=train_input_fn, steps=200)
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (conv_model): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
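
The rewritten conv_model follows the standard custom model_fn contract: compute predictions, return early for PREDICT, then build the loss, a train_op for TRAIN, and metric ops for EVAL. A stripped-down sketch of that skeleton, with a linear model standing in for the conv net (illustrative, not this commit's code):

import tensorflow as tf

def model_fn(features, labels, mode):
  logits = tf.layers.dense(features['x'], 10)
  predicted_classes = tf.argmax(logits, 1)

  # PREDICT: return class ids and probabilities; no loss is needed.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode, predictions={'class': predicted_classes,
                           'prob': tf.nn.softmax(logits)})

  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # TRAIN: attach an optimizer step that advances the global step.
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

  # EVAL: report streaming accuracy.
  eval_metric_ops = {
      'accuracy': tf.metrics.accuracy(labels, predicted_classes)}
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, eval_metric_ops=eval_metric_ops)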
diff --git a/tensorflow/examples/learn/multiple_gpu.py b/tensorflow/examples/learn/multiple_gpu.py
index df58906b39..c7364d1f72 100644
--- a/tensorflow/examples/learn/multiple_gpu.py
+++ b/tensorflow/examples/learn/multiple_gpu.py
@@ -20,75 +20,100 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from sklearn import cross_validation
+import numpy as np
from sklearn import datasets
from sklearn import metrics
+from sklearn import model_selection
import tensorflow as tf
-layers = tf.contrib.layers
-learn = tf.contrib.learn
+X_FEATURE = 'x' # Name of the input feature.
-def my_model(features, target):
+
+def my_model(features, labels, mode):
"""DNN with three hidden layers, and dropout of 0.1 probability.
Note: If you want to run this example with multiple GPUs, the CUDA Toolkit 7.0 and
cuDNN 6.5 V2 from NVIDIA need to be installed beforehand.
Args:
- features: `Tensor` of input features.
- target: `Tensor` of targets.
+ features: Dict of input `Tensor`.
+ labels: Label `Tensor`.
+ mode: One of `ModeKeys`.
Returns:
- Tuple of predictions, loss and training op.
+ `EstimatorSpec`.
"""
- # Convert the target to a one-hot tensor of shape (length of features, 3) and
- # with a on-value of 1 for each one-hot vector of length 3.
- target = tf.one_hot(target, 3, 1, 0)
-
# Create three fully connected layers respectively of size 10, 20, and 10 with
# each layer having a dropout probability of 0.1.
- normalizer_fn = layers.dropout
- normalizer_params = {'keep_prob': 0.5}
+ net = features[X_FEATURE]
with tf.device('/gpu:1'):
- features = layers.stack(
- features,
- layers.fully_connected, [10, 20, 10],
- normalizer_fn=normalizer_fn,
- normalizer_params=normalizer_params)
+ for units in [10, 20, 10]:
+ net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
+ net = tf.layers.dropout(
+ net, rate=0.1, training=(mode == tf.estimator.ModeKeys.TRAIN))
with tf.device('/gpu:2'):
- # Compute logits (1 per class) and compute loss.
- logits = layers.fully_connected(features, 3, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
-
- # Create a tensor for training op.
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adagrad',
- learning_rate=0.1)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ # Compute logits (1 per class).
+ logits = tf.layers.dense(net, 3, activation=None)
+
+ # Compute predictions.
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+
+ # Convert the labels to a one-hot tensor of shape (length of features, 3)
+ # and with an on-value of 1 for each one-hot vector of length 3.
+ onehot_labels = tf.one_hot(labels, 3, 1, 0)
+ # Compute loss.
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ # Create training op.
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
+ train_op = optimizer.minimize(
+ loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ # Compute evaluation metrics.
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=eval_metric_ops)
def main(unused_argv):
iris = datasets.load_iris()
- x_train, x_test, y_train, y_test = cross_validation.train_test_split(
+ x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
- classifier = learn.Estimator(model_fn=my_model)
- classifier.fit(x_train, y_train, steps=1000)
+ classifier = tf.estimator.Estimator(model_fn=my_model)
- y_predicted = [
- p['class'] for p in classifier.predict(
- x_test, as_iterable=True)
- ]
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
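
multiple_gpu.py's distinguishing feature is explicit device placement inside the model function. A minimal sketch of that pattern outside an Estimator (the device strings assume a machine with at least two GPUs; allow_soft_placement lets the ops fall back to other devices when they are absent):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
with tf.device('/gpu:0'):
  net = tf.layers.dense(x, 10, activation=tf.nn.relu)
with tf.device('/gpu:1'):
  logits = tf.layers.dense(net, 3)

config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
with tf.Session(config=config) as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(logits, {x: np.zeros((1, 4), np.float32)}))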
diff --git a/tensorflow/examples/learn/resnet.py b/tensorflow/examples/learn/resnet.py
index 881905fde8..33a09bb6e0 100755
--- a/tensorflow/examples/learn/resnet.py
+++ b/tensorflow/examples/learn/resnet.py
@@ -25,31 +25,17 @@ from __future__ import print_function
from collections import namedtuple
from math import sqrt
-import os
+import numpy as np
import tensorflow as tf
-batch_norm = tf.contrib.layers.batch_norm
-convolution2d = tf.contrib.layers.convolution2d
+N_DIGITS = 10 # Number of digits.
+X_FEATURE = 'x' # Name of the input feature.
-def res_net(x, y, activation=tf.nn.relu):
- """Builds a residual network.
- Note that if the input tensor is 2D, it must be square in order to be
- converted to a 4D tensor.
-
- Borrowed structure from:
- github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
-
- Args:
- x: Input of the network
- y: Output of the network
- activation: Activation function to apply after each convolution
-
- Returns:
- Predictions and loss tensors.
- """
+def res_net_model(features, labels, mode):
+ """Builds a residual network."""
# Configurations for each bottleneck group.
BottleneckGroup = namedtuple('BottleneckGroup',
@@ -59,6 +45,7 @@ def res_net(x, y, activation=tf.nn.relu):
BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]
+ x = features[X_FEATURE]
input_shape = x.get_shape().as_list()
# Reshape the input into the right shape if it's 2D tensor
@@ -68,15 +55,24 @@ def res_net(x, y, activation=tf.nn.relu):
# First convolution expands to 64 channels
with tf.variable_scope('conv_layer1'):
- net = convolution2d(
- x, 64, 7, normalizer_fn=batch_norm, activation_fn=activation)
+ net = tf.layers.conv2d(
+ x,
+ filters=64,
+ kernel_size=7,
+ activation=tf.nn.relu)
+ net = tf.layers.batch_normalization(net)
# Max pool
- net = tf.nn.max_pool(net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
+ net = tf.layers.max_pooling2d(
+ net, pool_size=3, strides=2, padding='same')
# First chain of resnets
with tf.variable_scope('conv_layer2'):
- net = convolution2d(net, groups[0].num_filters, 1, padding='VALID')
+ net = tf.layers.conv2d(
+ net,
+ filters=groups[0].num_filters,
+ kernel_size=1,
+ padding='valid')
# Create the bottleneck groups, each of which contains `num_blocks`
# bottleneck groups.
@@ -86,33 +82,33 @@ def res_net(x, y, activation=tf.nn.relu):
# 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_in'):
- conv = convolution2d(
+ conv = tf.layers.conv2d(
net,
- group.bottleneck_size,
- 1,
- padding='VALID',
- activation_fn=activation,
- normalizer_fn=batch_norm)
+ filters=group.bottleneck_size,
+ kernel_size=1,
+ padding='valid',
+ activation=tf.nn.relu)
+ conv = tf.layers.batch_normalization(conv)
with tf.variable_scope(name + '/conv_bottleneck'):
- conv = convolution2d(
+ conv = tf.layers.conv2d(
conv,
- group.bottleneck_size,
- 3,
- padding='SAME',
- activation_fn=activation,
- normalizer_fn=batch_norm)
+ filters=group.bottleneck_size,
+ kernel_size=3,
+ padding='same',
+ activation=tf.nn.relu)
+ conv = tf.layers.batch_normalization(conv)
# 1x1 convolution responsible for restoring dimension
with tf.variable_scope(name + '/conv_out'):
input_dim = net.get_shape()[-1].value
- conv = convolution2d(
+ conv = tf.layers.conv2d(
conv,
- input_dim,
- 1,
- padding='VALID',
- activation_fn=activation,
- normalizer_fn=batch_norm)
+ filters=input_dim,
+ kernel_size=1,
+ padding='valid',
+ activation=tf.nn.relu)
+ conv = tf.layers.batch_normalization(conv)
# shortcut connections that turn the network into its counterpart
# residual function (identity shortcut)
@@ -122,13 +118,13 @@ def res_net(x, y, activation=tf.nn.relu):
# upscale to the next group size
next_group = groups[group_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % group_i):
- net = convolution2d(
+ net = tf.layers.conv2d(
net,
- next_group.num_filters,
- 1,
- activation_fn=None,
- biases_initializer=None,
- padding='SAME')
+ filters=next_group.num_filters,
+ kernel_size=1,
+ padding='same',
+ activation=None,
+ bias_initializer=None)
except IndexError:
pass
@@ -142,48 +138,65 @@ def res_net(x, y, activation=tf.nn.relu):
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
- target = tf.one_hot(y, depth=10, dtype=tf.float32)
- logits = tf.contrib.layers.fully_connected(net, 10, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
- return tf.nn.softmax(logits), loss
-
-
-def res_net_model(x, y):
- prediction, loss = res_net(x, y)
- predicted = tf.argmax(prediction, 1)
- accuracy = tf.equal(predicted, tf.cast(y, tf.int64))
- predictions = {'prob': prediction, 'class': predicted, 'accuracy': accuracy}
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adagrad',
- learning_rate=0.001)
- return predictions, loss, train_op
-
-
-# Download and load MNIST data.
-mnist = tf.contrib.learn.datasets.load_dataset('mnist')
-
-# Create a new resnet classifier.
-classifier = tf.contrib.learn.Estimator(model_fn=res_net_model)
-
-tf.logging.set_verbosity(tf.logging.INFO) # Show training logs. (avoid silence)
-
-# Train model and save summaries into logdir.
-classifier.fit(mnist.train.images,
- mnist.train.labels,
- batch_size=100,
- steps=1000)
-
-# Calculate accuracy.
-result = classifier.evaluate(
- x=mnist.test.images,
- y=mnist.test.labels,
- metrics={
- 'accuracy':
- tf.contrib.learn.MetricSpec(
- metric_fn=tf.contrib.metrics.streaming_accuracy,
- prediction_key='accuracy'),
- })
-score = result['accuracy']
-print('Accuracy: {0:f}'.format(score))
+ # Compute logits (1 per class) and compute loss.
+ logits = tf.layers.dense(net, N_DIGITS, activation=None)
+
+ # Compute predictions.
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+
+ # Compute loss.
+ onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ # Create training op.
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ # Compute evaluation metrics.
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=eval_metric_ops)
+
+
+def main(unused_args):
+ # Download and load MNIST data.
+ mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
+
+ # Create a new resnet classifier.
+ classifier = tf.estimator.Estimator(model_fn=res_net_model)
+
+ tf.logging.set_verbosity(tf.logging.INFO) # Show training logs.
+
+ # Train model and save summaries into logdir.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: mnist.train.images},
+ y=mnist.train.labels.astype(np.int32),
+ batch_size=100,
+ num_epochs=None,
+ shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Calculate accuracy.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={X_FEATURE: mnist.test.images},
+ y=mnist.test.labels.astype(np.int32),
+ num_epochs=1,
+ shuffle=False)
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy: {0:f}'.format(scores['accuracy']))
+
+
+if __name__ == '__main__':
+ tf.app.run()
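
A minimal sketch of one bottleneck block with the identity shortcut that res_net_model builds, using the same tf.layers calls; the sizes are illustrative, and batch_normalization is left at its default inference behavior for brevity:

import tensorflow as tf

def bottleneck_block(net, bottleneck_size=32):
  input_dim = net.get_shape()[-1].value
  # 1x1 convolution reduces dimension to the bottleneck size.
  conv = tf.layers.conv2d(net, bottleneck_size, 1, padding='valid',
                          activation=tf.nn.relu)
  conv = tf.layers.batch_normalization(conv)
  # 3x3 convolution operates at the reduced dimension.
  conv = tf.layers.conv2d(conv, bottleneck_size, 3, padding='same',
                          activation=tf.nn.relu)
  conv = tf.layers.batch_normalization(conv)
  # 1x1 convolution restores the input dimension.
  conv = tf.layers.conv2d(conv, input_dim, 1, padding='valid',
                          activation=tf.nn.relu)
  conv = tf.layers.batch_normalization(conv)
  return net + conv  # Identity shortcut.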