aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image/cifar10/cifar10.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10.py151
1 files changed, 19 insertions, 132 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index b9b89473e8..32234db496 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -43,11 +43,9 @@ import tarfile
import tensorflow.python.platform
from six.moves import urllib
-from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10_input
-from tensorflow.python.platform import gfile
FLAGS = tf.app.flags.FLAGS
@@ -57,15 +55,12 @@ tf.app.flags.DEFINE_integer('batch_size', 128,
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
-# Process images of this size. Note that this differs from the original CIFAR
-# image size of 32 x 32. If one alters this number, then the entire model
-# architecture will change and any model would need to be retrained.
-IMAGE_SIZE = 24
-
# Global constants describing the CIFAR-10 data set.
-NUM_CLASSES = 10
-NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
-NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+IMAGE_SIZE = cifar10_input.IMAGE_SIZE
+NUM_CLASSES = cifar10_input.NUM_CLASSES
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
@@ -139,91 +134,21 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
return var
-def _generate_image_and_label_batch(image, label, min_queue_examples):
- """Construct a queued batch of images and labels.
-
- Args:
- image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32.
- label: 1-D Tensor of type.int32
- min_queue_examples: int32, minimum number of samples to retain
- in the queue that provides of batches of examples.
-
- Returns:
- images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
- labels: Labels. 1D tensor of [batch_size] size.
- """
- # Create a queue that shuffles the examples, and then
- # read 'FLAGS.batch_size' images + labels from the example queue.
- num_preprocess_threads = 16
- images, label_batch = tf.train.shuffle_batch(
- [image, label],
- batch_size=FLAGS.batch_size,
- num_threads=num_preprocess_threads,
- capacity=min_queue_examples + 3 * FLAGS.batch_size,
- min_after_dequeue=min_queue_examples)
-
- # Display the training images in the visualizer.
- tf.image_summary('images', images)
-
- return images, tf.reshape(label_batch, [FLAGS.batch_size])
-
-
def distorted_inputs():
"""Construct distorted input for CIFAR training using the Reader ops.
- Raises:
- ValueError: if no data_dir
-
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
- """
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'data_batch_%d.bin' % i)
- for i in xrange(1, 6)]
- for f in filenames:
- if not gfile.Exists(f):
- raise ValueError('Failed to find file: ' + f)
-
- # Create a queue that produces the filenames to read.
- filename_queue = tf.train.string_input_producer(filenames)
- # Read examples from files in the filename queue.
- read_input = cifar10_input.read_cifar10(filename_queue)
- reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
- height = IMAGE_SIZE
- width = IMAGE_SIZE
-
- # Image processing for training the network. Note the many random
- # distortions applied to the image.
-
- # Randomly crop a [height, width] section of the image.
- distorted_image = tf.image.random_crop(reshaped_image, [height, width])
-
- # Randomly flip the image horizontally.
- distorted_image = tf.image.random_flip_left_right(distorted_image)
-
- # Because these operations are not commutative, consider randomizing
- # randomize the order their operation.
- distorted_image = tf.image.random_brightness(distorted_image,
- max_delta=63)
- distorted_image = tf.image.random_contrast(distorted_image,
- lower=0.2, upper=1.8)
-
- # Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(distorted_image)
-
- # Ensure that the random shuffling has good mixing properties.
- min_fraction_of_examples_in_queue = 0.4
- min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
- min_fraction_of_examples_in_queue)
- print ('Filling queue with %d CIFAR images before starting to train. '
- 'This will take a few minutes.' % min_queue_examples)
-
- # Generate a batch of images and labels by building up a queue of examples.
- return _generate_image_and_label_batch(float_image, read_input.label,
- min_queue_examples)
+ Raises:
+ ValueError: If no data_dir
+ """
+ if not FLAGS.data_dir:
+ raise ValueError('Please supply a data_dir')
+ data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+ return cifar10_input.distorted_inputs(data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
def inputs(eval_data):
@@ -232,56 +157,18 @@ def inputs(eval_data):
Args:
eval_data: bool, indicating if one should use the train or eval data set.
- Raises:
- ValueError: if no data_dir
-
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
+
+ Raises:
+ ValueError: If no data_dir
"""
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
-
- if not eval_data:
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'data_batch_%d.bin' % i)
- for i in xrange(1, 6)]
- num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
- else:
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'test_batch.bin')]
- num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
-
- for f in filenames:
- if not gfile.Exists(f):
- raise ValueError('Failed to find file: ' + f)
-
- # Create a queue that produces the filenames to read.
- filename_queue = tf.train.string_input_producer(filenames)
-
- # Read examples from files in the filename queue.
- read_input = cifar10_input.read_cifar10(filename_queue)
- reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
- height = IMAGE_SIZE
- width = IMAGE_SIZE
-
- # Image processing for evaluation.
- # Crop the central [height, width] of the image.
- resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
- width, height)
-
- # Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(resized_image)
-
- # Ensure that the random shuffling has good mixing properties.
- min_fraction_of_examples_in_queue = 0.4
- min_queue_examples = int(num_examples_per_epoch *
- min_fraction_of_examples_in_queue)
-
- # Generate a batch of images and labels by building up a queue of examples.
- return _generate_image_and_label_batch(float_image, read_input.label,
- min_queue_examples)
+ data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+ return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
def inference(images):