diff options
Diffstat (limited to 'tensorflow/models/image/cifar10')
-rw-r--r-- | tensorflow/models/image/cifar10/BUILD | 79 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/README.md | 10 | ||||
-rwxr-xr-x | tensorflow/models/image/cifar10/__init__.py | 0 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10.py | 480 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_eval.py | 148 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_input.py | 65 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_input_test.py | 49 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py | 265 | ||||
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_train.py | 119 |
9 files changed, 1215 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD new file mode 100644 index 0000000000..adf9aaffd4 --- /dev/null +++ b/tensorflow/models/image/cifar10/BUILD @@ -0,0 +1,79 @@ +# Description: +# Example TensorFlow models for CIFAR-10 + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "cifar10_input", + srcs = ["cifar10_input.py"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cifar10_input_test", + srcs = ["cifar10_input_test.py"], + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "cifar10", + srcs = ["cifar10.py"], + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar10_eval", + srcs = [ + "cifar10_eval.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +py_binary( + name = "cifar10_train", + srcs = [ + "cifar10_train.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +py_binary( + name = "cifar10_multi_gpu_train", + srcs = [ + "cifar10_multi_gpu_train.py", + ], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cifar10", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/models/image/cifar10/README.md b/tensorflow/models/image/cifar10/README.md new file mode 100644 index 0000000000..67877aedc0 --- /dev/null +++ b/tensorflow/models/image/cifar10/README.md @@ -0,0 +1,10 @@ +CIFAR-10 is a common benchmark in machine learning for image recognition. + +http://www.cs.toronto.edu/~kriz/cifar.html + +Code in this directory demonstrates how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on both CPU and GPU. We also demonstrate how to train a CNN over multiple GPUs. + +Detailed instructions on how to get started available at: + +http://tensorflow.org/tutorials/deep_cnn/ + diff --git a/tensorflow/models/image/cifar10/__init__.py b/tensorflow/models/image/cifar10/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/models/image/cifar10/__init__.py diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py new file mode 100644 index 0000000000..7870080820 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -0,0 +1,480 @@ +"""Builds the CIFAR-10 network. + +Summary of available functions: + + # Compute input images and labels for training. If you would like to run + # evaluations, use input() instead. + inputs, labels = distorted_inputs() + + # Compute inference on the model inputs to make a prediction. + predictions = inference(inputs) + + # Compute the total loss of the prediction with respect to the labels. + loss = loss(predictions, labels) + + # Create a graph to run one step of training with respect to the loss. + train_op = train(loss, global_step) +""" +# pylint: disable=missing-docstring +import gzip +import os +import re +import sys +import tarfile +import urllib + +import tensorflow.python.platform +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10_input +from tensorflow.python.platform import gfile + +FLAGS = tf.app.flags.FLAGS + +# Basic model parameters. +tf.app.flags.DEFINE_integer('batch_size', 128, + """Number of images to process in a batch.""") +tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data', + """Path to the CIFAR-10 data directory.""") + +# Process images of this size. Note that this differs from the original CIFAR +# image size of 32 x 32. If one alters this number, then the entire model +# architecture will change and any model would need to be retrained. +IMAGE_SIZE = 24 + +# Global constants describing the CIFAR-10 data set. +NUM_CLASSES = 10 +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 + +# Constants describing the training process. +MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. +NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. +LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. +INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. + +# If a model is trained with multiple GPU's prefix all Op names with tower_name +# to differentiate the operations. Note that this prefix is removed from the +# names of the summaries when visualizing a model. +TOWER_NAME = 'tower' + +DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' + + +def _activation_summary(x): + """Helper to create summaries for activations. + + Creates a summary that provides a histogram of activations. + Creates a summary that measure the sparsity of activations. + + Args: + x: Tensor + Returns: + nothing + """ + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. + tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) + tf.histogram_summary(tensor_name + '/activations', x) + tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) + + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. + + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + var = tf.get_variable(name, shape, initializer=initializer) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd): + """Helper to create an initialized Variable with weight decay. + + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + + Returns: + Variable Tensor + """ + var = _variable_on_cpu(name, shape, + tf.truncated_normal_initializer(stddev=stddev)) + if wd: + weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') + tf.add_to_collection('losses', weight_decay) + return var + + +def _generate_image_and_label_batch(image, label, min_queue_examples): + """Construct a queued batch of images and labels. + + Args: + image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32. + label: 1-D Tensor of type.int32 + min_queue_examples: int32, minimum number of samples to retain + in the queue that provides of batches of examples. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + # Create a queue that shuffles the examples, and then + # read 'FLAGS.batch_size' images + labels from the example queue. + num_preprocess_threads = 16 + images, label_batch = tf.train.shuffle_batch( + [image, label], + batch_size=FLAGS.batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * FLAGS.batch_size, + min_after_dequeue=min_queue_examples) + + # Display the training images in the visualizer. + tf.image_summary('images', images) + + return images, tf.reshape(label_batch, [FLAGS.batch_size]) + + +def distorted_inputs(): + """Construct distorted input for CIFAR training using the Reader ops. + + Raises: + ValueError: if no data_dir + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'data_batch_%d.bin' % i) + for i in xrange(1, 5)] + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = cifar10_input.read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for training the network. Note the many random + # distortions applied to the image. + + # Randomly crop a [height, width] section of the image. + distorted_image = tf.image.random_crop(reshaped_image, [height, width]) + + # Randomly flip the image horizontally. + distorted_image = tf.image.random_flip_left_right(distorted_image) + + # Because these operations are not commutative, consider randomizing + # randomize the order their operation. + distorted_image = tf.image.random_brightness(distorted_image, + max_delta=63) + distorted_image = tf.image.random_contrast(distorted_image, + lower=0.2, upper=1.8) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(distorted_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * + min_fraction_of_examples_in_queue) + print ('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples) + + +def inputs(eval_data): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + + Raises: + ValueError: if no data_dir + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + if not FLAGS.data_dir: + raise ValueError('Please supply a data_dir') + + if not eval_data: + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'data_batch_%d.bin' % i) + for i in xrange(1, 5)] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + else: + filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', + 'test_batch.bin')] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + + for f in filenames: + if not gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = cifar10_input.read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for evaluation. + # Crop the central [height, width] of the image. + resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, + width, height) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_whitening(resized_image) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(num_examples_per_epoch * + min_fraction_of_examples_in_queue) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples) + + +def inference(images): + """Build the CIFAR-10 model. + + Args: + images: Images returned from distorted_inputs() or inputs(). + + Returns: + Logits. + """ + # We instantiate all variables using tf.get_variable() instead of + # tf.Variable() in order to share variables across multiple GPU training runs. + # If we only ran this model on a single GPU, we could simplify this function + # by replacing all instances of tf.get_variable() with tf.Variable(). + # + # conv1 + with tf.variable_scope('conv1') as scope: + kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], + stddev=1e-4, wd=0.0) + conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) + conv1 = tf.nn.relu(bias, name=scope.name) + _activation_summary(conv1) + + # pool1 + pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], + padding='SAME', name='pool1') + # norm1 + norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm1') + + # conv2 + with tf.variable_scope('conv2') as scope: + kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64], + stddev=1e-4, wd=0.0) + conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) + bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list()) + conv2 = tf.nn.relu(bias, name=scope.name) + _activation_summary(conv2) + + # norm2 + norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm2') + # pool2 + pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], padding='SAME', name='pool2') + + # local3 + with tf.variable_scope('local3') as scope: + # Move everything into depth so we can perform a single matrix multiply. + dim = 1 + for d in pool2.get_shape()[1:].as_list(): + dim *= d + reshape = tf.reshape(pool2, [FLAGS.batch_size, dim]) + + weights = _variable_with_weight_decay('weights', shape=[dim, 384], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) + local3 = tf.nn.relu_layer(reshape, weights, biases, name=scope.name) + _activation_summary(local3) + + # local4 + with tf.variable_scope('local4') as scope: + weights = _variable_with_weight_decay('weights', shape=[384, 192], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) + local4 = tf.nn.relu_layer(local3, weights, biases, name=scope.name) + _activation_summary(local4) + + # softmax, i.e. softmax(WX + b) + with tf.variable_scope('softmax_linear') as scope: + weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], + stddev=1/192.0, wd=0.0) + biases = _variable_on_cpu('biases', [NUM_CLASSES], + tf.constant_initializer(0.0)) + softmax_linear = tf.nn.xw_plus_b(local4, weights, biases, name=scope.name) + _activation_summary(softmax_linear) + + return softmax_linear + + +def loss(logits, labels): + """Add L2Loss to all the trainable variables. + + Add summary for for "Loss" and "Loss/avg". + Args: + logits: Logits from inference(). + labels: Labels from distorted_inputs or inputs(). 1-D tensor + of shape [batch_size] + + Returns: + Loss tensor of type float. + """ + # Reshape the labels into a dense Tensor of + # shape [batch_size, NUM_CLASSES]. + sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1]) + indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1]) + concated = tf.concat(1, [indices, sparse_labels]) + dense_labels = tf.sparse_to_dense(concated, + [FLAGS.batch_size, NUM_CLASSES], + 1.0, 0.0) + + # Calculate the average cross entropy loss across the batch. + cross_entropy = tf.nn.softmax_cross_entropy_with_logits( + logits, dense_labels, name='cross_entropy_per_example') + cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') + tf.add_to_collection('losses', cross_entropy_mean) + + # The total loss is defined as the cross entropy loss plus all of the weight + # decay terms (L2 loss). + return tf.add_n(tf.get_collection('losses'), name='total_loss') + + +def _add_loss_summaries(total_loss): + """Add summaries for losses in CIFAR-10 model. + + Generates moving average for all losses and associated summaries for + visualizing the performance of the network. + + Args: + total_loss: Total loss from loss(). + Returns: + loss_averages_op: op for generating moving averages of losses. + """ + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + losses = tf.get_collection('losses') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summmary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.scalar_summary(l.op.name +' (raw)', l) + tf.scalar_summary(l.op.name, loss_averages.average(l)) + + return loss_averages_op + + +def train(total_loss, global_step): + """Train CIFAR-10 model. + + Create an optimizer and apply to all trainable variables. Add moving + average for all trainable variables. + + Args: + total_loss: Total loss from loss(). + global_step: Integer Variable counting the number of training steps + processed. + Returns: + train_op: op for training. + """ + # Variables that affect learning rate. + num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size + decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, + global_step, + decay_steps, + LEARNING_RATE_DECAY_FACTOR, + staircase=True) + tf.scalar_summary('learning_rate', lr) + + # Generate moving averages of all losses and associated summaries. + loss_averages_op = _add_loss_summaries(total_loss) + + # Compute gradients. + with tf.control_dependencies([loss_averages_op]): + opt = tf.train.GradientDescentOptimizer(lr) + grads = opt.compute_gradients(total_loss) + + # Apply gradients. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + tf.histogram_summary(var.op.name, var) + + # Add histograms for gradients. + for grad, var in grads: + if grad: + tf.histogram_summary(var.op.name + '/gradients', grad) + + # Track the moving averages of all trainable variables. + variable_averages = tf.train.ExponentialMovingAverage( + MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + with tf.control_dependencies([apply_gradient_op, variables_averages_op]): + train_op = tf.no_op(name='train') + + return train_op + + +def maybe_download_and_extract(): + """Download and extract the tarball from Alex's website.""" + dest_directory = FLAGS.data_dir + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) + print + statinfo = os.stat(filepath) + print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.' + tarfile.open(filepath, 'r:gz').extractall(dest_directory) diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py new file mode 100644 index 0000000000..73c224191d --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_eval.py @@ -0,0 +1,148 @@ +"""Evaluation for CIFAR-10. + +Accuracy: +cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs +of data) as judged by cifar10_eval.py. + +Speed: +On a single Tesla K40, cifar10_train.py processes a single batch of 128 images +in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86% +accuracy after 100K steps in 8 hours of training time. + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import math +import time + +import tensorflow.python.platform +from tensorflow.python.platform import gfile +import numpy as np +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10 + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval', + """Directory where to write event logs.""") +tf.app.flags.DEFINE_string('eval_data', 'test', + """Either 'test' or 'train_eval'.""") +tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train', + """Directory where to read model checkpoints.""") +tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, + """How often to run the eval.""") +tf.app.flags.DEFINE_integer('num_examples', 10000, + """Number of examples to run.""") +tf.app.flags.DEFINE_boolean('run_once', False, + """Whether to run eval only once.""") + + +def eval_once(saver, summary_writer, top_k_op, summary_op): + """Run Eval once. + + Args: + saver: Saver. + summary_writer: Summary writer. + top_k_op: Top K op. + summary_op: Summary op. + """ + with tf.Session() as sess: + ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + # Restores from checkpoint + saver.restore(sess, ckpt.model_checkpoint_path) + # Assuming model_checkpoint_path looks something like: + # /my-favorite-path/cifar10_train/model.ckpt-0, + # extract global_step from it. + global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] + else: + print 'No checkpoint file found' + return + + # Start the queue runners. + coord = tf.train.Coordinator() + try: + threads = [] + for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): + threads.extend(qr.create_threads(sess, coord=coord, daemon=True, + start=True)) + + num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = num_iter * FLAGS.batch_size + step = 0 + while step < num_iter and not coord.should_stop(): + predictions = sess.run([top_k_op]) + true_count += np.sum(predictions) + step += 1 + + # Compute precision @ 1. + precision = float(true_count) / float(total_sample_count) + print '%s: precision @ 1 = %.3f' % (datetime.now(), precision) + + summary = tf.Summary() + summary.ParseFromString(sess.run(summary_op)) + summary.value.add(tag='Precision @ 1', simple_value=precision) + summary_writer.add_summary(summary, global_step) + except Exception, e: # pylint: disable=broad-except + coord.request_stop(e) + + coord.request_stop() + coord.join(threads, stop_grace_period_secs=10) + + +def evaluate(): + """Eval CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + # Get images and labels for CIFAR-10. + eval_data = FLAGS.eval_data == 'test' + images, labels = cifar10.inputs(eval_data=eval_data) + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate predictions. + top_k_op = tf.nn.in_top_k(logits, labels, 1) + + # Restore the moving average version of the learned variables for eval. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY) + variables_to_restore = {} + for v in tf.all_variables(): + if v in tf.trainable_variables(): + restore_name = variable_averages.average_name(v) + else: + restore_name = v.op.name + variables_to_restore[restore_name] = v + saver = tf.train.Saver(variables_to_restore) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + graph_def = tf.get_default_graph().as_graph_def() + summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, + graph_def=graph_def) + + while True: + eval_once(saver, summary_writer, top_k_op, summary_op) + if FLAGS.run_once: + break + time.sleep(FLAGS.eval_interval_secs) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.eval_dir): + gfile.DeleteRecursively(FLAGS.eval_dir) + gfile.MakeDirs(FLAGS.eval_dir) + evaluate() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py new file mode 100644 index 0000000000..686f1bf987 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -0,0 +1,65 @@ +"""Routine for decoding the CIFAR-10 binary file format.""" + +import tensorflow.python.platform +import tensorflow as tf + + +def read_cifar10(filename_queue): + """Reads and parses examples from CIFAR10 data files. + + Recommendation: if you want N-way read parallelism, call this function + N times. This will give you N independent Readers reading different + files & positions within those files, which will give better mixing of + examples. + + Args: + filename_queue: A queue of strings with the filenames to read from. + + Returns: + An object representing a single example, with the following fields: + height: number of rows in the result (32) + width: number of columns in the result (32) + depth: number of color channels in the result (3) + key: a scalar string Tensor describing the filename & record number + for this example. + label: an int32 Tensor with the label in the range 0..9. + uint8image: a [height, width, depth] uint8 Tensor with the image data + """ + + class CIFAR10Record(object): + pass + result = CIFAR10Record() + + # Dimensions of the images in the CIFAR-10 dataset. + # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the + # input format. + label_bytes = 1 # 2 for CIFAR-100 + result.height = 32 + result.width = 32 + result.depth = 3 + image_bytes = result.height * result.width * result.depth + # Every record consists of a label followed by the image, with a + # fixed number of bytes for each. + record_bytes = label_bytes + image_bytes + + # Read a record, getting filenames from the filename_queue. No + # header or footer in the CIFAR-10 format, so we leave header_bytes + # and footer_bytes at their default of 0. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + result.key, value = reader.read(filename_queue) + + # Convert from a string to a vector of uint8 that is record_bytes long. + record_bytes = tf.decode_raw(value, tf.uint8) + + # The first bytes represent the label, which we convert from uint8->int32. + result.label = tf.cast( + tf.slice(record_bytes, [0], [label_bytes]), tf.int32) + + # The remaining bytes after the label represent the image, which we reshape + # from [depth * height * width] to [depth, height, width]. + depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), + [result.depth, result.height, result.width]) + # Convert from [depth, height, width] to [height, width, depth]. + result.uint8image = tf.transpose(depth_major, [1, 2, 0]) + + return result diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py new file mode 100644 index 0000000000..d43f5aedcf --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_input_test.py @@ -0,0 +1,49 @@ +"""Tests for cifar10 input.""" + +import os + +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10_input + + +class CIFAR10InputTest(tf.test.TestCase): + + def _record(self, label, red, green, blue): + image_size = 32 * 32 + record = "%s%s%s%s" % (chr(label), chr(red) * image_size, + chr(green) * image_size, chr(blue) * image_size) + expected = [[[red, green, blue]] * 32] * 32 + return record, expected + + def testSimple(self): + labels = [9, 3, 0] + records = [self._record(labels[0], 0, 128, 255), + self._record(labels[1], 255, 0, 1), + self._record(labels[2], 254, 255, 0)] + contents = "".join([record for record, _ in records]) + expected = [expected for _, expected in records] + filename = os.path.join(self.get_temp_dir(), "cifar") + open(filename, "w").write(contents) + + with self.test_session() as sess: + q = tf.FIFOQueue(99, [tf.string], shapes=()) + q.enqueue([filename]).run() + q.close().run() + result = cifar10_input.read_cifar10(q) + + for i in range(3): + key, label, uint8image = sess.run([ + result.key, result.label, result.uint8image]) + self.assertEqual("%s:%d" % (filename, i), key) + self.assertEqual(labels[i], label) + self.assertAllEqual(expected[i], uint8image) + + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run([result.key, result.uint8image]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py new file mode 100644 index 0000000000..54bc41f444 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py @@ -0,0 +1,265 @@ +"""A binary to train CIFAR-10 using multiple GPU's with synchronous updates. + +Accuracy: +cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256 +epochs of data) as judged by cifar10_eval.py. + +Speed: With batch_size 128. + +System | Step Time (sec/batch) | Accuracy +-------------------------------------------------------------------- +1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) +1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) +2 Tesla K20m | 0.13-0.20 | ~84% at 30K steps (2.5 hours) +3 Tesla K20m | 0.13-0.18 | ~84% at 30K steps +4 Tesla K20m | ~0.10 | ~84% at 30K steps + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import os.path +import re +import time + +# pylint: disable=unused-import,g-bad-import-order +import tensorflow.python.platform +from tensorflow.python.platform import gfile +import numpy as np +import tensorflow as tf +from tensorflow.models.image.cifar10 import cifar10 +# pylint: disable=unused-import,g-bad-import-order + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', + """Directory where to write event logs """ + """and checkpoint.""") +tf.app.flags.DEFINE_integer('max_steps', 1000000, + """Number of batches to run.""") +tf.app.flags.DEFINE_integer('num_gpus', 1, + """How many GPUs to use.""") +tf.app.flags.DEFINE_boolean('log_device_placement', False, + """Whether to log device placement.""") + + +def tower_loss(scope): + """Calculate the total loss on a single tower running the CIFAR model. + + Args: + scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0' + + Returns: + Tensor of shape [] containing the total loss for a batch of data + """ + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build inference Graph. + logits = cifar10.inference(images) + + # Build the portion of the Graph calculating the losses. Note that we will + # assemble the total_loss using a custom function below. + _ = cifar10.loss(logits, labels) + + # Assemble all of the losses for the current tower only. + losses = tf.get_collection('losses', scope) + + # Calculate the total loss for the current tower. + total_loss = tf.add_n(losses, name='total_loss') + + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summmary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. + loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name) + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.scalar_summary(loss_name +' (raw)', l) + tf.scalar_summary(loss_name, loss_averages.average(l)) + + with tf.control_dependencies([loss_averages_op]): + total_loss = tf.identity(total_loss) + return total_loss + + +def average_gradients(tower_grads): + """Calculate the average gradient for each shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over individual gradients. The inner list is over the gradient + calculation for each tower. + Returns: + List of pairs of (gradient, variable) where the gradient has been averaged + across all towers. + """ + average_grads = [] + for grad_and_vars in zip(*tower_grads): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + grads = [] + for g, _ in grad_and_vars: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) + + # Append on a 'tower' dimension which we will average over below. + grads.append(expanded_g) + + # Average over the 'tower' dimension. + grad = tf.concat(0, grads) + grad = tf.reduce_mean(grad, 0) + + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + return average_grads + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(), tf.device('/cpu:0'): + # Create a variable to count the number of train() calls. This equals the + # number of batches processed * FLAGS.num_gpus. + global_step = tf.get_variable( + 'global_step', [], + initializer=tf.constant_initializer(0), trainable=False) + + # Calculate the learning rate schedule. + num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / + FLAGS.batch_size) + decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE, + global_step, + decay_steps, + cifar10.LEARNING_RATE_DECAY_FACTOR, + staircase=True) + + # Create an optimizer that performs gradient descent. + opt = tf.train.GradientDescentOptimizer(lr) + + # Calculate the gradients for each model tower. + tower_grads = [] + for i in xrange(FLAGS.num_gpus): + with tf.device('/gpu:%d' % i): + with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope: + # Calculate the loss for one tower of the CIFAR model. This function + # constructs the entire CIFAR model but shares the variables across + # all towers. + loss = tower_loss(scope) + + # Reuse variables for the next tower. + tf.get_variable_scope().reuse_variables() + + # Retain the summaries from the final tower. + summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) + + # Calculate the gradients for the batch of data on this CIFAR tower. + grads = opt.compute_gradients(loss) + + # Keep track of the gradients across all towers. + tower_grads.append(grads) + + # We must calculate the mean of each gradient. Note that this is the + # synchronization point across all towers. + grads = average_gradients(tower_grads) + + # Add a summary to track the learning rate. + summaries.append(tf.scalar_summary('learning_rate', lr)) + + # Add histograms for gradients. + for grad, var in grads: + if grad: + summaries.append( + tf.histogram_summary(var.op.name + '/gradients', grad)) + + # Apply the gradients to adjust the shared variables. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + summaries.append(tf.histogram_summary(var.op.name, var)) + + # Track the moving averages of all trainable variables. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + # Group all updates to into a single train op. + train_op = tf.group(apply_gradient_op, variables_averages_op) + + # Create a saver. + saver = tf.train.Saver(tf.all_variables()) + + # Build the summary operation from the last tower summaries. + summary_op = tf.merge_summary(summaries) + + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + + # Start running operations on the Graph. allow_soft_placement must be set to + # True to build towers on GPU, as some of the ops do not have GPU + # implementations. + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=FLAGS.log_device_placement)) + sess.run(init) + + # Start the queue runners. + tf.train.start_queue_runners(sess=sess) + + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, + graph_def=sess.graph_def) + + for step in xrange(FLAGS.max_steps): + start_time = time.time() + _, loss_value = sess.run([train_op, loss]) + duration = time.time() - start_time + + assert not np.isnan(loss_value), 'Model diverged with loss = NaN' + + if step % 10 == 0: + num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus + examples_per_sec = num_examples_per_step / float(duration) + sec_per_batch = float(duration) / FLAGS.num_gpus + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print (format_str % (datetime.now(), step, loss_value, + examples_per_sec, sec_per_batch)) + + if step % 100 == 0: + summary_str = sess.run(summary_op) + summary_writer.add_summary(summary_str, step) + + # Save the model checkpoint periodically. + if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: + checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') + saver.save(sess, checkpoint_path, global_step=step) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.train_dir): + gfile.DeleteRecursively(FLAGS.train_dir) + gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py new file mode 100644 index 0000000000..bcb6eeae58 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_train.py @@ -0,0 +1,119 @@ +"""A binary to train CIFAR-10 using a single GPU. + +Accuracy: +cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of +data) as judged by cifar10_eval.py. + +Speed: With batch_size 128. + +System | Step Time (sec/batch) | Accuracy +------------------------------------------------------------------ +1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) +1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from datetime import datetime +import os.path +import time + +import tensorflow.python.platform +from tensorflow.python.platform import gfile + +import numpy as np + +import tensorflow as tf + +from tensorflow.models.image.cifar10 import cifar10 + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', + """Directory where to write event logs """ + """and checkpoint.""") +tf.app.flags.DEFINE_integer('max_steps', 1000000, + """Number of batches to run.""") +tf.app.flags.DEFINE_boolean('log_device_placement', False, + """Whether to log device placement.""") + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + global_step = tf.Variable(0, trainable=False) + + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate loss. + loss = cifar10.loss(logits, labels) + + # Build a Graph that trains the model with one batch of examples and + # updates the model parameters. + train_op = cifar10.train(loss, global_step) + + # Create a saver. + saver = tf.train.Saver(tf.all_variables()) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + # Build an initialization operation to run below. + init = tf.initialize_all_variables() + + # Start running operations on the Graph. + sess = tf.Session(config=tf.ConfigProto( + log_device_placement=FLAGS.log_device_placement)) + sess.run(init) + + # Start the queue runners. + tf.train.start_queue_runners(sess=sess) + + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, + graph_def=sess.graph_def) + + for step in xrange(FLAGS.max_steps): + start_time = time.time() + _, loss_value = sess.run([train_op, loss]) + duration = time.time() - start_time + + assert not np.isnan(loss_value), 'Model diverged with loss = NaN' + + if step % 10 == 0: + num_examples_per_step = FLAGS.batch_size + examples_per_sec = num_examples_per_step / float(duration) + sec_per_batch = float(duration) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print (format_str % (datetime.now(), step, loss_value, + examples_per_sec, sec_per_batch)) + + if step % 100 == 0: + summary_str = sess.run(summary_op) + summary_writer.add_summary(summary_str, step) + + # Save the model checkpoint periodically. + if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: + checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') + saver.save(sess, checkpoint_path, global_step=step) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if gfile.Exists(FLAGS.train_dir): + gfile.DeleteRecursively(FLAGS.train_dir) + gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + tf.app.run() |