diff options
Diffstat (limited to 'tensorflow/g3doc/how_tos/reading_data/convert_to_records.py')
-rw-r--r-- | tensorflow/g3doc/how_tos/reading_data/convert_to_records.py | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py new file mode 100644 index 0000000000..1d510cdfa9 --- /dev/null +++ b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py @@ -0,0 +1,87 @@ +"""Converts MNIST data to TFRecords file format with Example protos.""" + +import os +import tensorflow.python.platform + +import numpy +import tensorflow as tf +from tensorflow.g3doc.tutorials.mnist import input_data + + +TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames +TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' +TEST_IMAGES = 't10k-images-idx3-ubyte.gz' +TEST_LABELS = 't10k-labels-idx1-ubyte.gz' + + +tf.app.flags.DEFINE_string('directory', 'data', + 'Directory to download data files and write the ' + 'converted result') +tf.app.flags.DEFINE_integer('validation_size', 5000, + 'Number of examples to separate from the training ' + 'data for the validation set.') +FLAGS = tf.app.flags.FLAGS + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def convert_to(images, labels, name): + num_examples = labels.shape[0] + if images.shape[0] != num_examples: + raise ValueError("Images size %d does not match label size %d." % + (dat.shape[0], num_examples)) + rows = images.shape[1] + cols = images.shape[2] + depth = images.shape[3] + + filename = os.path.join(FLAGS.directory, name + '.tfrecords') + print 'Writing', filename + writer = tf.python_io.TFRecordWriter(filename) + for index in range(num_examples): + image_raw = images[index].tostring() + example = tf.train.Example(features=tf.train.Features(feature={ + 'height':_int64_feature(rows), + 'width':_int64_feature(cols), + 'depth':_int64_feature(depth), + 'label':_int64_feature(int(labels[index])), + 'image_raw':_bytes_feature(image_raw)})) + writer.write(example.SerializeToString()) + + +def main(argv): + # Get the data. + train_images_filename = input_data.maybe_download( + TRAIN_IMAGES, FLAGS.directory) + train_labels_filename = input_data.maybe_download( + TRAIN_LABELS, FLAGS.directory) + test_images_filename = input_data.maybe_download( + TEST_IMAGES, FLAGS.directory) + test_labels_filename = input_data.maybe_download( + TEST_LABELS, FLAGS.directory) + + # Extract it into numpy arrays. + train_images = input_data.extract_images(train_images_filename) + train_labels = input_data.extract_labels(train_labels_filename) + test_images = input_data.extract_images(test_images_filename) + test_labels = input_data.extract_labels(test_labels_filename) + + # Generate a validation set. + validation_images = train_images[:FLAGS.validation_size, :, :, :] + validation_labels = train_labels[:FLAGS.validation_size] + train_images = train_images[FLAGS.validation_size:, :, :, :] + train_labels = train_labels[FLAGS.validation_size:] + + # Convert to Examples and write the result to TFRecords. + convert_to(train_images, train_labels, 'train') + convert_to(validation_images, validation_labels, 'validation') + convert_to(test_images, test_labels, 'test') + + +if __name__ == '__main__': + tf.app.run() |