diff options
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_input.py')
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_input.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py new file mode 100644 index 0000000000..686f1bf987 --- /dev/null +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -0,0 +1,65 @@ +"""Routine for decoding the CIFAR-10 binary file format.""" + +import tensorflow.python.platform +import tensorflow as tf + + +def read_cifar10(filename_queue): + """Reads and parses examples from CIFAR10 data files. + + Recommendation: if you want N-way read parallelism, call this function + N times. This will give you N independent Readers reading different + files & positions within those files, which will give better mixing of + examples. + + Args: + filename_queue: A queue of strings with the filenames to read from. + + Returns: + An object representing a single example, with the following fields: + height: number of rows in the result (32) + width: number of columns in the result (32) + depth: number of color channels in the result (3) + key: a scalar string Tensor describing the filename & record number + for this example. + label: an int32 Tensor with the label in the range 0..9. + uint8image: a [height, width, depth] uint8 Tensor with the image data + """ + + class CIFAR10Record(object): + pass + result = CIFAR10Record() + + # Dimensions of the images in the CIFAR-10 dataset. + # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the + # input format. + label_bytes = 1 # 2 for CIFAR-100 + result.height = 32 + result.width = 32 + result.depth = 3 + image_bytes = result.height * result.width * result.depth + # Every record consists of a label followed by the image, with a + # fixed number of bytes for each. + record_bytes = label_bytes + image_bytes + + # Read a record, getting filenames from the filename_queue. No + # header or footer in the CIFAR-10 format, so we leave header_bytes + # and footer_bytes at their default of 0. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + result.key, value = reader.read(filename_queue) + + # Convert from a string to a vector of uint8 that is record_bytes long. + record_bytes = tf.decode_raw(value, tf.uint8) + + # The first bytes represent the label, which we convert from uint8->int32. + result.label = tf.cast( + tf.slice(record_bytes, [0], [label_bytes]), tf.int32) + + # The remaining bytes after the label represent the image, which we reshape + # from [depth * height * width] to [depth, height, width]. + depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), + [result.depth, result.height, result.width]) + # Convert from [depth, height, width] to [height, width, depth]. + result.uint8image = tf.transpose(depth_major, [1, 2, 0]) + + return result |