aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image/cifar10/cifar10_input.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_input.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10_input.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
new file mode 100644
index 0000000000..686f1bf987
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -0,0 +1,65 @@
+"""Routine for decoding the CIFAR-10 binary file format."""
+
+import tensorflow.python.platform
+import tensorflow as tf
+
+
+def read_cifar10(filename_queue):
+ """Reads and parses examples from CIFAR10 data files.
+
+ Recommendation: if you want N-way read parallelism, call this function
+ N times. This will give you N independent Readers reading different
+ files & positions within those files, which will give better mixing of
+ examples.
+
+ Args:
+ filename_queue: A queue of strings with the filenames to read from.
+
+ Returns:
+ An object representing a single example, with the following fields:
+ height: number of rows in the result (32)
+ width: number of columns in the result (32)
+ depth: number of color channels in the result (3)
+ key: a scalar string Tensor describing the filename & record number
+ for this example.
+ label: an int32 Tensor with the label in the range 0..9.
+ uint8image: a [height, width, depth] uint8 Tensor with the image data
+ """
+
+ class CIFAR10Record(object):
+ pass
+ result = CIFAR10Record()
+
+ # Dimensions of the images in the CIFAR-10 dataset.
+ # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+ # input format.
+ label_bytes = 1 # 2 for CIFAR-100
+ result.height = 32
+ result.width = 32
+ result.depth = 3
+ image_bytes = result.height * result.width * result.depth
+ # Every record consists of a label followed by the image, with a
+ # fixed number of bytes for each.
+ record_bytes = label_bytes + image_bytes
+
+ # Read a record, getting filenames from the filename_queue. No
+ # header or footer in the CIFAR-10 format, so we leave header_bytes
+ # and footer_bytes at their default of 0.
+ reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
+ result.key, value = reader.read(filename_queue)
+
+ # Convert from a string to a vector of uint8 that is record_bytes long.
+ record_bytes = tf.decode_raw(value, tf.uint8)
+
+ # The first bytes represent the label, which we convert from uint8->int32.
+ result.label = tf.cast(
+ tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
+
+ # The remaining bytes after the label represent the image, which we reshape
+ # from [depth * height * width] to [depth, height, width].
+ depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
+ [result.depth, result.height, result.width])
+ # Convert from [depth, height, width] to [height, width, depth].
+ result.uint8image = tf.transpose(depth_major, [1, 2, 0])
+
+ return result