Diffstat (limited to 'tensorflow/models/image/cifar10')
-rw-r--r--  tensorflow/models/image/cifar10/BUILD                         |  79
-rw-r--r--  tensorflow/models/image/cifar10/README.md                     |  10
-rwxr-xr-x  tensorflow/models/image/cifar10/__init__.py                   |   0
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py                    | 480
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_eval.py               | 148
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_input.py              |  65
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_input_test.py         |  49
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py   | 265
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_train.py              | 119
9 files changed, 1215 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD
new file mode 100644
index 0000000000..adf9aaffd4
--- /dev/null
+++ b/tensorflow/models/image/cifar10/BUILD
@@ -0,0 +1,79 @@
+# Description:
+# Example TensorFlow models for CIFAR-10
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "cifar10_input",
+ srcs = ["cifar10_input.py"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "cifar10_input_test",
+ srcs = ["cifar10_input_test.py"],
+ deps = [
+ ":cifar10_input",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_library(
+ name = "cifar10",
+ srcs = ["cifar10.py"],
+ deps = [
+ ":cifar10_input",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "cifar10_eval",
+ srcs = [
+ "cifar10_eval.py",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":cifar10",
+ ],
+)
+
+py_binary(
+ name = "cifar10_train",
+ srcs = [
+ "cifar10_train.py",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":cifar10",
+ ],
+)
+
+py_binary(
+ name = "cifar10_multi_gpu_train",
+ srcs = [
+ "cifar10_multi_gpu_train.py",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":cifar10",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/models/image/cifar10/README.md b/tensorflow/models/image/cifar10/README.md
new file mode 100644
index 0000000000..67877aedc0
--- /dev/null
+++ b/tensorflow/models/image/cifar10/README.md
@@ -0,0 +1,10 @@
+CIFAR-10 is a common benchmark in machine learning for image recognition.
+
+http://www.cs.toronto.edu/~kriz/cifar.html
+
+Code in this directory demonstrates how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on both CPU and GPU. We also demonstrate how to train a CNN over multiple GPUs.
+
+Detailed instructions on how to get started are available at:
+
+http://tensorflow.org/tutorials/deep_cnn/
+
diff --git a/tensorflow/models/image/cifar10/__init__.py b/tensorflow/models/image/cifar10/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/models/image/cifar10/__init__.py
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
new file mode 100644
index 0000000000..7870080820
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -0,0 +1,480 @@
+"""Builds the CIFAR-10 network.
+
+Summary of available functions:
+
+ # Compute input images and labels for training. If you would like to run
+  # evaluations, use inputs() instead.
+ inputs, labels = distorted_inputs()
+
+ # Compute inference on the model inputs to make a prediction.
+ predictions = inference(inputs)
+
+ # Compute the total loss of the prediction with respect to the labels.
+ loss = loss(predictions, labels)
+
+ # Create a graph to run one step of training with respect to the loss.
+ train_op = train(loss, global_step)
+"""
+# pylint: disable=missing-docstring
+import gzip
+import os
+import re
+import sys
+import tarfile
+import urllib
+
+import tensorflow.python.platform
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10_input
+from tensorflow.python.platform import gfile
+
+FLAGS = tf.app.flags.FLAGS
+
+# Basic model parameters.
+tf.app.flags.DEFINE_integer('batch_size', 128,
+ """Number of images to process in a batch.""")
+tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
+ """Path to the CIFAR-10 data directory.""")
+
+# Process images of this size. Note that this differs from the original CIFAR
+# image size of 32 x 32. If one alters this number, then the entire model
+# architecture will change and any model would need to be retrained.
+IMAGE_SIZE = 24
+
+# Global constants describing the CIFAR-10 data set.
+NUM_CLASSES = 10
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+
+# Constants describing the training process.
+MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
+NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
+LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
+INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
+
+# If a model is trained with multiple GPUs, prefix all Op names with tower_name
+# to differentiate the operations. Note that this prefix is removed from the
+# names of the summaries when visualizing a model.
+TOWER_NAME = 'tower'
+
+DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+
+
+def _activation_summary(x):
+ """Helper to create summaries for activations.
+
+ Creates a summary that provides a histogram of activations.
+  Creates a summary that measures the sparsity of activations.
+
+ Args:
+ x: Tensor
+ Returns:
+ nothing
+ """
+ # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+ # session. This helps the clarity of presentation on tensorboard.
+ tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
+ tf.histogram_summary(tensor_name + '/activations', x)
+ tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+
+
+def _variable_on_cpu(name, shape, initializer):
+ """Helper to create a Variable stored on CPU memory.
+
+ Args:
+ name: name of the variable
+ shape: list of ints
+ initializer: initializer for Variable
+
+ Returns:
+ Variable Tensor
+ """
+ with tf.device('/cpu:0'):
+ var = tf.get_variable(name, shape, initializer=initializer)
+ return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+ """Helper to create an initialized Variable with weight decay.
+
+ Note that the Variable is initialized with a truncated normal distribution.
+ A weight decay is added only if one is specified.
+
+ Args:
+ name: name of the variable
+ shape: list of ints
+ stddev: standard deviation of a truncated Gaussian
+ wd: add L2Loss weight decay multiplied by this float. If None, weight
+ decay is not added for this Variable.
+
+ Returns:
+ Variable Tensor
+ """
+ var = _variable_on_cpu(name, shape,
+ tf.truncated_normal_initializer(stddev=stddev))
+ if wd:
+ weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
+ tf.add_to_collection('losses', weight_decay)
+ return var
+
+
+def _generate_image_and_label_batch(image, label, min_queue_examples):
+ """Construct a queued batch of images and labels.
+
+ Args:
+    image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type float32.
+    label: 1-D Tensor of type int32.
+    min_queue_examples: int32, minimum number of samples to retain
+      in the queue that provides batches of examples.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ # Create a queue that shuffles the examples, and then
+ # read 'FLAGS.batch_size' images + labels from the example queue.
+ num_preprocess_threads = 16
+ images, label_batch = tf.train.shuffle_batch(
+ [image, label],
+ batch_size=FLAGS.batch_size,
+ num_threads=num_preprocess_threads,
+ capacity=min_queue_examples + 3 * FLAGS.batch_size,
+ min_after_dequeue=min_queue_examples)
+
+ # Display the training images in the visualizer.
+ tf.image_summary('images', images)
+
+ return images, tf.reshape(label_batch, [FLAGS.batch_size])
+
+
+def distorted_inputs():
+ """Construct distorted input for CIFAR training using the Reader ops.
+
+ Raises:
+ ValueError: if no data_dir
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
+ 'data_batch_%d.bin' % i)
+               for i in xrange(1, 6)]
+ for f in filenames:
+ if not gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = cifar10_input.read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for training the network. Note the many random
+ # distortions applied to the image.
+
+ # Randomly crop a [height, width] section of the image.
+ distorted_image = tf.image.random_crop(reshaped_image, [height, width])
+
+ # Randomly flip the image horizontally.
+ distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+  # Because these operations are not commutative, consider randomizing
+  # the order of their operation.
+ distorted_image = tf.image.random_brightness(distorted_image,
+ max_delta=63)
+ distorted_image = tf.image.random_contrast(distorted_image,
+ lower=0.2, upper=1.8)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_whitening(distorted_image)
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
+ min_fraction_of_examples_in_queue)
+ print ('Filling queue with %d CIFAR images before starting to train. '
+ 'This will take a few minutes.' % min_queue_examples)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples)
+
+
+def inputs(eval_data):
+ """Construct input for CIFAR evaluation using the Reader ops.
+
+ Args:
+ eval_data: bool, indicating if one should use the train or eval data set.
+
+ Raises:
+ ValueError: if no data_dir
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ if not FLAGS.data_dir:
+ raise ValueError('Please supply a data_dir')
+
+ if not eval_data:
+ filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
+ 'data_batch_%d.bin' % i)
+                 for i in xrange(1, 6)]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+ else:
+ filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
+ 'test_batch.bin')]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
+ for f in filenames:
+ if not gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = cifar10_input.read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for evaluation.
+ # Crop the central [height, width] of the image.
+ resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
+ width, height)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_whitening(resized_image)
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(num_examples_per_epoch *
+ min_fraction_of_examples_in_queue)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples)
+
+
+def inference(images):
+ """Build the CIFAR-10 model.
+
+ Args:
+ images: Images returned from distorted_inputs() or inputs().
+
+ Returns:
+ Logits.
+ """
+ # We instantiate all variables using tf.get_variable() instead of
+ # tf.Variable() in order to share variables across multiple GPU training runs.
+ # If we only ran this model on a single GPU, we could simplify this function
+ # by replacing all instances of tf.get_variable() with tf.Variable().
+ #
+ # conv1
+ with tf.variable_scope('conv1') as scope:
+ kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
+ stddev=1e-4, wd=0.0)
+ conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
+ biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
+ bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list())
+ conv1 = tf.nn.relu(bias, name=scope.name)
+ _activation_summary(conv1)
+
+ # pool1
+ pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+ padding='SAME', name='pool1')
+ # norm1
+ norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+ name='norm1')
+
+ # conv2
+ with tf.variable_scope('conv2') as scope:
+ kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
+ stddev=1e-4, wd=0.0)
+ conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
+ biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
+ bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list())
+ conv2 = tf.nn.relu(bias, name=scope.name)
+ _activation_summary(conv2)
+
+ # norm2
+ norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+ name='norm2')
+ # pool2
+ pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1], padding='SAME', name='pool2')
+
+ # local3
+ with tf.variable_scope('local3') as scope:
+ # Move everything into depth so we can perform a single matrix multiply.
+ dim = 1
+ for d in pool2.get_shape()[1:].as_list():
+ dim *= d
+ reshape = tf.reshape(pool2, [FLAGS.batch_size, dim])
+
+ weights = _variable_with_weight_decay('weights', shape=[dim, 384],
+ stddev=0.04, wd=0.004)
+ biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
+ local3 = tf.nn.relu_layer(reshape, weights, biases, name=scope.name)
+ _activation_summary(local3)
+
+ # local4
+ with tf.variable_scope('local4') as scope:
+ weights = _variable_with_weight_decay('weights', shape=[384, 192],
+ stddev=0.04, wd=0.004)
+ biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+ local4 = tf.nn.relu_layer(local3, weights, biases, name=scope.name)
+ _activation_summary(local4)
+
+ # softmax, i.e. softmax(WX + b)
+ with tf.variable_scope('softmax_linear') as scope:
+ weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
+ stddev=1/192.0, wd=0.0)
+ biases = _variable_on_cpu('biases', [NUM_CLASSES],
+ tf.constant_initializer(0.0))
+ softmax_linear = tf.nn.xw_plus_b(local4, weights, biases, name=scope.name)
+ _activation_summary(softmax_linear)
+
+ return softmax_linear
+
+
+def loss(logits, labels):
+ """Add L2Loss to all the trainable variables.
+
+  Add summary for "Loss" and "Loss/avg".
+ Args:
+ logits: Logits from inference().
+ labels: Labels from distorted_inputs or inputs(). 1-D tensor
+ of shape [batch_size]
+
+ Returns:
+ Loss tensor of type float.
+ """
+ # Reshape the labels into a dense Tensor of
+ # shape [batch_size, NUM_CLASSES].
+ sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
+ indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
+ concated = tf.concat(1, [indices, sparse_labels])
+ dense_labels = tf.sparse_to_dense(concated,
+ [FLAGS.batch_size, NUM_CLASSES],
+ 1.0, 0.0)
+
+ # Calculate the average cross entropy loss across the batch.
+ cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
+ logits, dense_labels, name='cross_entropy_per_example')
+ cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
+ tf.add_to_collection('losses', cross_entropy_mean)
+
+ # The total loss is defined as the cross entropy loss plus all of the weight
+ # decay terms (L2 loss).
+ return tf.add_n(tf.get_collection('losses'), name='total_loss')
+
+
+def _add_loss_summaries(total_loss):
+ """Add summaries for losses in CIFAR-10 model.
+
+ Generates moving average for all losses and associated summaries for
+ visualizing the performance of the network.
+
+ Args:
+ total_loss: Total loss from loss().
+ Returns:
+ loss_averages_op: op for generating moving averages of losses.
+ """
+ # Compute the moving average of all individual losses and the total loss.
+ loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+ losses = tf.get_collection('losses')
+ loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+  # Attach a scalar summary to all individual losses and the total loss; do the
+ # same for the averaged version of the losses.
+ for l in losses + [total_loss]:
+ # Name each loss as '(raw)' and name the moving average version of the loss
+ # as the original loss name.
+ tf.scalar_summary(l.op.name +' (raw)', l)
+ tf.scalar_summary(l.op.name, loss_averages.average(l))
+
+ return loss_averages_op
+
+
+def train(total_loss, global_step):
+ """Train CIFAR-10 model.
+
+ Create an optimizer and apply to all trainable variables. Add moving
+ average for all trainable variables.
+
+ Args:
+ total_loss: Total loss from loss().
+ global_step: Integer Variable counting the number of training steps
+ processed.
+ Returns:
+ train_op: op for training.
+ """
+ # Variables that affect learning rate.
+ num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
+ decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
+
+ # Decay the learning rate exponentially based on the number of steps.
+ lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
+ global_step,
+ decay_steps,
+ LEARNING_RATE_DECAY_FACTOR,
+ staircase=True)
+ tf.scalar_summary('learning_rate', lr)
+
+ # Generate moving averages of all losses and associated summaries.
+ loss_averages_op = _add_loss_summaries(total_loss)
+
+ # Compute gradients.
+ with tf.control_dependencies([loss_averages_op]):
+ opt = tf.train.GradientDescentOptimizer(lr)
+ grads = opt.compute_gradients(total_loss)
+
+ # Apply gradients.
+ apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+ # Add histograms for trainable variables.
+ for var in tf.trainable_variables():
+ tf.histogram_summary(var.op.name, var)
+
+ # Add histograms for gradients.
+ for grad, var in grads:
+    if grad is not None:
+ tf.histogram_summary(var.op.name + '/gradients', grad)
+
+ # Track the moving averages of all trainable variables.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ MOVING_AVERAGE_DECAY, global_step)
+ variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+ with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
+ train_op = tf.no_op(name='train')
+
+ return train_op
+
+
+def maybe_download_and_extract():
+ """Download and extract the tarball from Alex's website."""
+ dest_directory = FLAGS.data_dir
+ if not os.path.exists(dest_directory):
+ os.makedirs(dest_directory)
+ filename = DATA_URL.split('/')[-1]
+ filepath = os.path.join(dest_directory, filename)
+ if not os.path.exists(filepath):
+ def _progress(count, block_size, total_size):
+ sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
+ float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.flush()
+ filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
+ print
+ statinfo = os.stat(filepath)
+  print 'Successfully downloaded', filename, statinfo.st_size, 'bytes.'
+ tarfile.open(filepath, 'r:gz').extractall(dest_directory)
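
Note on loss() above: the concat + sparse_to_dense step is just a one-hot encoding of the integer labels into the dense [batch_size, NUM_CLASSES] target matrix. A minimal NumPy sketch of the same transformation, with illustrative batch size and label values (not part of the diff):

    import numpy as np

    batch_size = 4      # illustrative; the model uses FLAGS.batch_size (128)
    num_classes = 10    # NUM_CLASSES
    labels = np.array([3, 0, 9, 3], dtype=np.int32)   # example int32 labels

    # Equivalent of building [indices, sparse_labels] and calling sparse_to_dense:
    dense_labels = np.zeros((batch_size, num_classes), dtype=np.float32)
    dense_labels[np.arange(batch_size), labels] = 1.0
    # dense_labels[0] is [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
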
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
new file mode 100644
index 0000000000..73c224191d
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -0,0 +1,148 @@
+"""Evaluation for CIFAR-10.
+
+Accuracy:
+cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
+of data) as judged by cifar10_eval.py.
+
+Speed:
+On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
+in 0.25-0.35 sec (i.e. 350 - 600 images/sec). The model reaches ~86%
+accuracy after 100K steps in 8 hours of training time.
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from datetime import datetime
+import math
+import time
+
+import tensorflow.python.platform
+from tensorflow.python.platform import gfile
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
+ """Directory where to write event logs.""")
+tf.app.flags.DEFINE_string('eval_data', 'test',
+ """Either 'test' or 'train_eval'.""")
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
+ """Directory where to read model checkpoints.""")
+tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ """How often to run the eval.""")
+tf.app.flags.DEFINE_integer('num_examples', 10000,
+ """Number of examples to run.""")
+tf.app.flags.DEFINE_boolean('run_once', False,
+ """Whether to run eval only once.""")
+
+
+def eval_once(saver, summary_writer, top_k_op, summary_op):
+ """Run Eval once.
+
+ Args:
+ saver: Saver.
+ summary_writer: Summary writer.
+ top_k_op: Top K op.
+ summary_op: Summary op.
+ """
+ with tf.Session() as sess:
+ ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ # Restores from checkpoint
+ saver.restore(sess, ckpt.model_checkpoint_path)
+ # Assuming model_checkpoint_path looks something like:
+ # /my-favorite-path/cifar10_train/model.ckpt-0,
+ # extract global_step from it.
+ global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
+ else:
+ print 'No checkpoint file found'
+ return
+
+ # Start the queue runners.
+ coord = tf.train.Coordinator()
+ try:
+ threads = []
+ for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
+ start=True))
+
+      num_iter = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))
+ true_count = 0 # Counts the number of correct predictions.
+ total_sample_count = num_iter * FLAGS.batch_size
+ step = 0
+ while step < num_iter and not coord.should_stop():
+ predictions = sess.run([top_k_op])
+ true_count += np.sum(predictions)
+ step += 1
+
+ # Compute precision @ 1.
+ precision = float(true_count) / float(total_sample_count)
+ print '%s: precision @ 1 = %.3f' % (datetime.now(), precision)
+
+ summary = tf.Summary()
+ summary.ParseFromString(sess.run(summary_op))
+ summary.value.add(tag='Precision @ 1', simple_value=precision)
+ summary_writer.add_summary(summary, global_step)
+ except Exception, e: # pylint: disable=broad-except
+ coord.request_stop(e)
+
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=10)
+
+
+def evaluate():
+ """Eval CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default():
+ # Get images and labels for CIFAR-10.
+ eval_data = FLAGS.eval_data == 'test'
+ images, labels = cifar10.inputs(eval_data=eval_data)
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate predictions.
+ top_k_op = tf.nn.in_top_k(logits, labels, 1)
+
+ # Restore the moving average version of the learned variables for eval.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ cifar10.MOVING_AVERAGE_DECAY)
+ variables_to_restore = {}
+ for v in tf.all_variables():
+ if v in tf.trainable_variables():
+ restore_name = variable_averages.average_name(v)
+ else:
+ restore_name = v.op.name
+ variables_to_restore[restore_name] = v
+ saver = tf.train.Saver(variables_to_restore)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ graph_def = tf.get_default_graph().as_graph_def()
+ summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
+ graph_def=graph_def)
+
+ while True:
+ eval_once(saver, summary_writer, top_k_op, summary_op)
+ if FLAGS.run_once:
+ break
+ time.sleep(FLAGS.eval_interval_secs)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if gfile.Exists(FLAGS.eval_dir):
+ gfile.DeleteRecursively(FLAGS.eval_dir)
+ gfile.MakeDirs(FLAGS.eval_dir)
+ evaluate()
+
+
+if __name__ == '__main__':
+ tf.app.run()
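
For reference, the counting loop in eval_once() above reduces to tallying top-1 hits and dividing by the number of evaluated examples. A rough NumPy analogue of tf.nn.in_top_k with k=1 plus that division, using made-up logits and labels:

    import numpy as np

    logits = np.array([[0.1, 2.0, 0.3],    # made-up logits for 3 classes
                       [1.5, 0.2, 0.1],
                       [0.0, 0.1, 3.0],
                       [2.2, 0.1, 0.3]])
    labels = np.array([1, 0, 2, 1])        # true classes; the last row is wrong

    top_1_correct = np.argmax(logits, axis=1) == labels   # like in_top_k(..., 1)
    true_count = int(np.sum(top_1_correct))               # 3 correct predictions
    precision = float(true_count) / labels.size           # 0.75, i.e. precision @ 1
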
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
new file mode 100644
index 0000000000..686f1bf987
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -0,0 +1,65 @@
+"""Routine for decoding the CIFAR-10 binary file format."""
+
+import tensorflow.python.platform
+import tensorflow as tf
+
+
+def read_cifar10(filename_queue):
+ """Reads and parses examples from CIFAR10 data files.
+
+ Recommendation: if you want N-way read parallelism, call this function
+ N times. This will give you N independent Readers reading different
+ files & positions within those files, which will give better mixing of
+ examples.
+
+ Args:
+ filename_queue: A queue of strings with the filenames to read from.
+
+ Returns:
+ An object representing a single example, with the following fields:
+ height: number of rows in the result (32)
+ width: number of columns in the result (32)
+ depth: number of color channels in the result (3)
+ key: a scalar string Tensor describing the filename & record number
+ for this example.
+ label: an int32 Tensor with the label in the range 0..9.
+ uint8image: a [height, width, depth] uint8 Tensor with the image data
+ """
+
+ class CIFAR10Record(object):
+ pass
+ result = CIFAR10Record()
+
+ # Dimensions of the images in the CIFAR-10 dataset.
+ # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+ # input format.
+ label_bytes = 1 # 2 for CIFAR-100
+ result.height = 32
+ result.width = 32
+ result.depth = 3
+ image_bytes = result.height * result.width * result.depth
+ # Every record consists of a label followed by the image, with a
+ # fixed number of bytes for each.
+ record_bytes = label_bytes + image_bytes
+
+ # Read a record, getting filenames from the filename_queue. No
+ # header or footer in the CIFAR-10 format, so we leave header_bytes
+ # and footer_bytes at their default of 0.
+ reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
+ result.key, value = reader.read(filename_queue)
+
+ # Convert from a string to a vector of uint8 that is record_bytes long.
+ record_bytes = tf.decode_raw(value, tf.uint8)
+
+ # The first bytes represent the label, which we convert from uint8->int32.
+ result.label = tf.cast(
+ tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
+
+ # The remaining bytes after the label represent the image, which we reshape
+ # from [depth * height * width] to [depth, height, width].
+ depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
+ [result.depth, result.height, result.width])
+ # Convert from [depth, height, width] to [height, width, depth].
+ result.uint8image = tf.transpose(depth_major, [1, 2, 0])
+
+ return result
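
The fixed-length record layout that read_cifar10() decodes can also be unpacked with plain NumPy. A self-contained sketch using a synthetic record (the label and pixel values are made up):

    import numpy as np

    height, width, depth = 32, 32, 3
    label_bytes = 1
    image_bytes = height * width * depth          # 3072
    record_bytes = label_bytes + image_bytes      # 3073 bytes per example

    # Synthetic record: label 7 followed by a constant-valued image.
    record = np.concatenate([np.array([7], dtype=np.uint8),
                             np.full(image_bytes, 128, dtype=np.uint8)])
    assert record.size == record_bytes

    label = int(record[0])                                           # 7
    depth_major = record[label_bytes:].reshape(depth, height, width)
    image = depth_major.transpose(1, 2, 0)        # [height, width, depth] uint8
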
diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py
new file mode 100644
index 0000000000..d43f5aedcf
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_input_test.py
@@ -0,0 +1,49 @@
+"""Tests for cifar10 input."""
+
+import os
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10_input
+
+
+class CIFAR10InputTest(tf.test.TestCase):
+
+ def _record(self, label, red, green, blue):
+ image_size = 32 * 32
+ record = "%s%s%s%s" % (chr(label), chr(red) * image_size,
+ chr(green) * image_size, chr(blue) * image_size)
+ expected = [[[red, green, blue]] * 32] * 32
+ return record, expected
+
+ def testSimple(self):
+ labels = [9, 3, 0]
+ records = [self._record(labels[0], 0, 128, 255),
+ self._record(labels[1], 255, 0, 1),
+ self._record(labels[2], 254, 255, 0)]
+ contents = "".join([record for record, _ in records])
+ expected = [expected for _, expected in records]
+ filename = os.path.join(self.get_temp_dir(), "cifar")
+ open(filename, "w").write(contents)
+
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(99, [tf.string], shapes=())
+ q.enqueue([filename]).run()
+ q.close().run()
+ result = cifar10_input.read_cifar10(q)
+
+ for i in range(3):
+ key, label, uint8image = sess.run([
+ result.key, result.label, result.uint8image])
+ self.assertEqual("%s:%d" % (filename, i), key)
+ self.assertEqual(labels[i], label)
+ self.assertAllEqual(expected[i], uint8image)
+
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run([result.key, result.uint8image])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
new file mode 100644
index 0000000000..54bc41f444
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
@@ -0,0 +1,265 @@
+"""A binary to train CIFAR-10 using multiple GPU's with synchronous updates.
+
+Accuracy:
+cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256
+epochs of data) as judged by cifar10_eval.py.
+
+Speed: With batch_size 128.
+
+System | Step Time (sec/batch) | Accuracy
+--------------------------------------------------------------------
+1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
+1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
+2 Tesla K20m | 0.13-0.20 | ~84% at 30K steps (2.5 hours)
+3 Tesla K20m | 0.13-0.18 | ~84% at 30K steps
+4 Tesla K20m | ~0.10 | ~84% at 30K steps
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from datetime import datetime
+import os.path
+import re
+import time
+
+# pylint: disable=unused-import,g-bad-import-order
+import tensorflow.python.platform
+from tensorflow.python.platform import gfile
+import numpy as np
+import tensorflow as tf
+from tensorflow.models.image.cifar10 import cifar10
+# pylint: enable=unused-import,g-bad-import-order
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
+ """Directory where to write event logs """
+ """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000,
+ """Number of batches to run.""")
+tf.app.flags.DEFINE_integer('num_gpus', 1,
+ """How many GPUs to use.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+ """Whether to log device placement.""")
+
+
+def tower_loss(scope):
+ """Calculate the total loss on a single tower running the CIFAR model.
+
+ Args:
+ scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'
+
+ Returns:
+ Tensor of shape [] containing the total loss for a batch of data
+ """
+ # Get images and labels for CIFAR-10.
+ images, labels = cifar10.distorted_inputs()
+
+ # Build inference Graph.
+ logits = cifar10.inference(images)
+
+ # Build the portion of the Graph calculating the losses. Note that we will
+ # assemble the total_loss using a custom function below.
+ _ = cifar10.loss(logits, labels)
+
+ # Assemble all of the losses for the current tower only.
+ losses = tf.get_collection('losses', scope)
+
+ # Calculate the total loss for the current tower.
+ total_loss = tf.add_n(losses, name='total_loss')
+
+ # Compute the moving average of all individual losses and the total loss.
+ loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+ loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+  # Attach a scalar summary to all individual losses and the total loss; do the
+ # same for the averaged version of the losses.
+ for l in losses + [total_loss]:
+ # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+ # session. This helps the clarity of presentation on tensorboard.
+ loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
+ # Name each loss as '(raw)' and name the moving average version of the loss
+ # as the original loss name.
+ tf.scalar_summary(loss_name +' (raw)', l)
+ tf.scalar_summary(loss_name, loss_averages.average(l))
+
+ with tf.control_dependencies([loss_averages_op]):
+ total_loss = tf.identity(total_loss)
+ return total_loss
+
+
+def average_gradients(tower_grads):
+ """Calculate the average gradient for each shared variable across all towers.
+
+ Note that this function provides a synchronization point across all towers.
+
+ Args:
+ tower_grads: List of lists of (gradient, variable) tuples. The outer list
+ is over individual gradients. The inner list is over the gradient
+ calculation for each tower.
+ Returns:
+ List of pairs of (gradient, variable) where the gradient has been averaged
+ across all towers.
+ """
+ average_grads = []
+ for grad_and_vars in zip(*tower_grads):
+ # Note that each grad_and_vars looks like the following:
+ # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+ grads = []
+ for g, _ in grad_and_vars:
+ # Add 0 dimension to the gradients to represent the tower.
+ expanded_g = tf.expand_dims(g, 0)
+
+ # Append on a 'tower' dimension which we will average over below.
+ grads.append(expanded_g)
+
+ # Average over the 'tower' dimension.
+ grad = tf.concat(0, grads)
+ grad = tf.reduce_mean(grad, 0)
+
+ # Keep in mind that the Variables are redundant because they are shared
+    # across towers, so we will just return the first tower's pointer to
+ # the Variable.
+ v = grad_and_vars[0][1]
+ grad_and_var = (grad, v)
+ average_grads.append(grad_and_var)
+ return average_grads
+
+
+def train():
+ """Train CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default(), tf.device('/cpu:0'):
+ # Create a variable to count the number of train() calls. This equals the
+ # number of batches processed * FLAGS.num_gpus.
+ global_step = tf.get_variable(
+ 'global_step', [],
+ initializer=tf.constant_initializer(0), trainable=False)
+
+ # Calculate the learning rate schedule.
+ num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
+ FLAGS.batch_size)
+ decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)
+
+ # Decay the learning rate exponentially based on the number of steps.
+ lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
+ global_step,
+ decay_steps,
+ cifar10.LEARNING_RATE_DECAY_FACTOR,
+ staircase=True)
+
+ # Create an optimizer that performs gradient descent.
+ opt = tf.train.GradientDescentOptimizer(lr)
+
+ # Calculate the gradients for each model tower.
+ tower_grads = []
+ for i in xrange(FLAGS.num_gpus):
+ with tf.device('/gpu:%d' % i):
+ with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+ # Calculate the loss for one tower of the CIFAR model. This function
+ # constructs the entire CIFAR model but shares the variables across
+ # all towers.
+ loss = tower_loss(scope)
+
+ # Reuse variables for the next tower.
+ tf.get_variable_scope().reuse_variables()
+
+ # Retain the summaries from the final tower.
+ summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+ # Calculate the gradients for the batch of data on this CIFAR tower.
+ grads = opt.compute_gradients(loss)
+
+ # Keep track of the gradients across all towers.
+ tower_grads.append(grads)
+
+ # We must calculate the mean of each gradient. Note that this is the
+ # synchronization point across all towers.
+ grads = average_gradients(tower_grads)
+
+ # Add a summary to track the learning rate.
+ summaries.append(tf.scalar_summary('learning_rate', lr))
+
+ # Add histograms for gradients.
+ for grad, var in grads:
+      if grad is not None:
+ summaries.append(
+ tf.histogram_summary(var.op.name + '/gradients', grad))
+
+ # Apply the gradients to adjust the shared variables.
+ apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+ # Add histograms for trainable variables.
+ for var in tf.trainable_variables():
+ summaries.append(tf.histogram_summary(var.op.name, var))
+
+ # Track the moving averages of all trainable variables.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ cifar10.MOVING_AVERAGE_DECAY, global_step)
+ variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+    # Group all updates into a single train op.
+ train_op = tf.group(apply_gradient_op, variables_averages_op)
+
+ # Create a saver.
+ saver = tf.train.Saver(tf.all_variables())
+
+ # Build the summary operation from the last tower summaries.
+ summary_op = tf.merge_summary(summaries)
+
+ # Build an initialization operation to run below.
+ init = tf.initialize_all_variables()
+
+ # Start running operations on the Graph. allow_soft_placement must be set to
+ # True to build towers on GPU, as some of the ops do not have GPU
+ # implementations.
+ sess = tf.Session(config=tf.ConfigProto(
+ allow_soft_placement=True,
+ log_device_placement=FLAGS.log_device_placement))
+ sess.run(init)
+
+ # Start the queue runners.
+ tf.train.start_queue_runners(sess=sess)
+
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ for step in xrange(FLAGS.max_steps):
+ start_time = time.time()
+ _, loss_value = sess.run([train_op, loss])
+ duration = time.time() - start_time
+
+ assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+ if step % 10 == 0:
+ num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
+ examples_per_sec = num_examples_per_step / float(duration)
+ sec_per_batch = float(duration) / FLAGS.num_gpus
+
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print (format_str % (datetime.now(), step, loss_value,
+ examples_per_sec, sec_per_batch))
+
+ if step % 100 == 0:
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+
+ # Save the model checkpoint periodically.
+ if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+ checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
+ saver.save(sess, checkpoint_path, global_step=step)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if gfile.Exists(FLAGS.train_dir):
+ gfile.DeleteRecursively(FLAGS.train_dir)
+ gfile.MakeDirs(FLAGS.train_dir)
+ train()
+
+
+if __name__ == '__main__':
+ tf.app.run()
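
The core of average_gradients() above is a per-variable mean over the tower dimension. A small NumPy sketch of that reduction for one shared variable and two towers (gradient values are illustrative):

    import numpy as np

    grad_tower0 = np.array([[1.0, 2.0], [3.0, 4.0]])   # gradient from /gpu:0
    grad_tower1 = np.array([[3.0, 4.0], [5.0, 6.0]])   # gradient from /gpu:1

    # expand_dims + concat over a new leading 'tower' axis, then reduce_mean:
    stacked = np.stack([grad_tower0, grad_tower1], axis=0)   # shape (2, 2, 2)
    averaged = stacked.mean(axis=0)                          # [[2., 3.], [4., 5.]]
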
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
new file mode 100644
index 0000000000..bcb6eeae58
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -0,0 +1,119 @@
+"""A binary to train CIFAR-10 using a single GPU.
+
+Accuracy:
+cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
+data) as judged by cifar10_eval.py.
+
+Speed: With batch_size 128.
+
+System | Step Time (sec/batch) | Accuracy
+------------------------------------------------------------------
+1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
+1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from datetime import datetime
+import os.path
+import time
+
+import tensorflow.python.platform
+from tensorflow.python.platform import gfile
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
+ """Directory where to write event logs """
+ """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000,
+ """Number of batches to run.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+ """Whether to log device placement.""")
+
+
+def train():
+ """Train CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default():
+ global_step = tf.Variable(0, trainable=False)
+
+ # Get images and labels for CIFAR-10.
+ images, labels = cifar10.distorted_inputs()
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate loss.
+ loss = cifar10.loss(logits, labels)
+
+ # Build a Graph that trains the model with one batch of examples and
+ # updates the model parameters.
+ train_op = cifar10.train(loss, global_step)
+
+ # Create a saver.
+ saver = tf.train.Saver(tf.all_variables())
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Build an initialization operation to run below.
+ init = tf.initialize_all_variables()
+
+ # Start running operations on the Graph.
+ sess = tf.Session(config=tf.ConfigProto(
+ log_device_placement=FLAGS.log_device_placement))
+ sess.run(init)
+
+ # Start the queue runners.
+ tf.train.start_queue_runners(sess=sess)
+
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ for step in xrange(FLAGS.max_steps):
+ start_time = time.time()
+ _, loss_value = sess.run([train_op, loss])
+ duration = time.time() - start_time
+
+ assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+ if step % 10 == 0:
+ num_examples_per_step = FLAGS.batch_size
+ examples_per_sec = num_examples_per_step / float(duration)
+ sec_per_batch = float(duration)
+
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print (format_str % (datetime.now(), step, loss_value,
+ examples_per_sec, sec_per_batch))
+
+ if step % 100 == 0:
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+
+ # Save the model checkpoint periodically.
+ if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+ checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
+ saver.save(sess, checkpoint_path, global_step=step)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if gfile.Exists(FLAGS.train_dir):
+ gfile.DeleteRecursively(FLAGS.train_dir)
+ gfile.MakeDirs(FLAGS.train_dir)
+ train()
+
+
+if __name__ == '__main__':
+ tf.app.run()
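
Finally, a back-of-the-envelope sketch of the staircase learning-rate schedule that cifar10.train() sets up from the constants above (INITIAL_LEARNING_RATE 0.1, decay factor 0.1 every NUM_EPOCHS_PER_DECAY = 350 epochs at batch_size 128). Plain Python, so rounding may differ slightly from tf.train.exponential_decay:

    import math

    initial_lr = 0.1            # INITIAL_LEARNING_RATE
    decay_factor = 0.1          # LEARNING_RATE_DECAY_FACTOR
    batch_size = 128
    examples_per_epoch = 50000  # NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    epochs_per_decay = 350.0    # NUM_EPOCHS_PER_DECAY

    # Roughly 390 batches per epoch * 350 epochs ~= 136K steps per decay.
    decay_steps = int(examples_per_epoch / float(batch_size) * epochs_per_decay)

    def staircase_lr(global_step):
        # Mirrors tf.train.exponential_decay(..., staircase=True).
        return initial_lr * decay_factor ** math.floor(float(global_step) / decay_steps)

    print(staircase_lr(100000))   # 0.1, still in the first stage
    print(staircase_lr(200000))   # ~0.01 once global_step passes decay_steps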