diff options
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10.py')
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10.py | 151 |
1 files changed, 19 insertions, 132 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index b9b89473e8..32234db496 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -43,11 +43,9 @@ import tarfile import tensorflow.python.platform from six.moves import urllib -from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.models.image.cifar10 import cifar10_input -from tensorflow.python.platform import gfile FLAGS = tf.app.flags.FLAGS @@ -57,15 +55,12 @@ tf.app.flags.DEFINE_integer('batch_size', 128, tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data', """Path to the CIFAR-10 data directory.""") -# Process images of this size. Note that this differs from the original CIFAR -# image size of 32 x 32. If one alters this number, then the entire model -# architecture will change and any model would need to be retrained. -IMAGE_SIZE = 24 - # Global constants describing the CIFAR-10 data set. -NUM_CLASSES = 10 -NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 -NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 +IMAGE_SIZE = cifar10_input.IMAGE_SIZE +NUM_CLASSES = cifar10_input.NUM_CLASSES +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + # Constants describing the training process. MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. @@ -139,91 +134,21 @@ def _variable_with_weight_decay(name, shape, stddev, wd): return var -def _generate_image_and_label_batch(image, label, min_queue_examples): - """Construct a queued batch of images and labels. - - Args: - image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32. - label: 1-D Tensor of type.int32 - min_queue_examples: int32, minimum number of samples to retain - in the queue that provides of batches of examples. - - Returns: - images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. - labels: Labels. 1D tensor of [batch_size] size. - """ - # Create a queue that shuffles the examples, and then - # read 'FLAGS.batch_size' images + labels from the example queue. - num_preprocess_threads = 16 - images, label_batch = tf.train.shuffle_batch( - [image, label], - batch_size=FLAGS.batch_size, - num_threads=num_preprocess_threads, - capacity=min_queue_examples + 3 * FLAGS.batch_size, - min_after_dequeue=min_queue_examples) - - # Display the training images in the visualizer. - tf.image_summary('images', images) - - return images, tf.reshape(label_batch, [FLAGS.batch_size]) - - def distorted_inputs(): """Construct distorted input for CIFAR training using the Reader ops. - Raises: - ValueError: if no data_dir - Returns: images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. - """ - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] - for f in filenames: - if not gfile.Exists(f): - raise ValueError('Failed to find file: ' + f) - - # Create a queue that produces the filenames to read. - filename_queue = tf.train.string_input_producer(filenames) - # Read examples from files in the filename queue. - read_input = cifar10_input.read_cifar10(filename_queue) - reshaped_image = tf.cast(read_input.uint8image, tf.float32) - - height = IMAGE_SIZE - width = IMAGE_SIZE - - # Image processing for training the network. Note the many random - # distortions applied to the image. - - # Randomly crop a [height, width] section of the image. - distorted_image = tf.image.random_crop(reshaped_image, [height, width]) - - # Randomly flip the image horizontally. - distorted_image = tf.image.random_flip_left_right(distorted_image) - - # Because these operations are not commutative, consider randomizing - # randomize the order their operation. - distorted_image = tf.image.random_brightness(distorted_image, - max_delta=63) - distorted_image = tf.image.random_contrast(distorted_image, - lower=0.2, upper=1.8) - - # Subtract off the mean and divide by the variance of the pixels. - float_image = tf.image.per_image_whitening(distorted_image) - - # Ensure that the random shuffling has good mixing properties. - min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * - min_fraction_of_examples_in_queue) - print ('Filling queue with %d CIFAR images before starting to train. ' - 'This will take a few minutes.' % min_queue_examples) - - # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples) + Raises: + ValueError: If no data_dir + """ + if not FLAGS.data_dir: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') + return cifar10_input.distorted_inputs(data_dir=data_dir, + batch_size=FLAGS.batch_size) def inputs(eval_data): @@ -232,56 +157,18 @@ def inputs(eval_data): Args: eval_data: bool, indicating if one should use the train or eval data set. - Raises: - ValueError: if no data_dir - Returns: images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. + + Raises: + ValueError: If no data_dir """ if not FLAGS.data_dir: raise ValueError('Please supply a data_dir') - - if not eval_data: - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] - num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN - else: - filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin', - 'test_batch.bin')] - num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL - - for f in filenames: - if not gfile.Exists(f): - raise ValueError('Failed to find file: ' + f) - - # Create a queue that produces the filenames to read. - filename_queue = tf.train.string_input_producer(filenames) - - # Read examples from files in the filename queue. - read_input = cifar10_input.read_cifar10(filename_queue) - reshaped_image = tf.cast(read_input.uint8image, tf.float32) - - height = IMAGE_SIZE - width = IMAGE_SIZE - - # Image processing for evaluation. - # Crop the central [height, width] of the image. - resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, - width, height) - - # Subtract off the mean and divide by the variance of the pixels. - float_image = tf.image.per_image_whitening(resized_image) - - # Ensure that the random shuffling has good mixing properties. - min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(num_examples_per_epoch * - min_fraction_of_examples_in_queue) - - # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples) + data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') + return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir, + batch_size=FLAGS.batch_size) def inference(images): |