diff options
Diffstat (limited to 'tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py')
-rw-r--r-- | tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py | 180 |
1 files changed, 180 insertions, 0 deletions
# File: tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
"""Train and Eval the MNIST network.

This version is like fully_connected_feed.py but uses data converted
to a TFRecords file containing tf.train.Example protocol buffers.
See tensorflow/g3doc/how_tos/reading_data.md#reading-from-files
for context.

YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""

import os.path
import time

import tensorflow.python.platform
import numpy
import tensorflow as tf

from tensorflow.g3doc.tutorials.mnist import mnist


# Basic model parameters as external flags.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.')
flags.DEFINE_string('train_dir', 'data', 'Directory with the training data.')

# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'


def read_and_decode(filename_queue):
    """Reads one example from the queue and decodes it.

    Args:
      filename_queue: Queue of TFRecords filenames, handed to
        tf.TFRecordReader.read().

    Returns:
      A tuple (image, label), where:
      * image is a float32 tensor with shape [mnist.IMAGE_PIXELS], scaled
        from [0, 255] to [-0.5, 0.5].
      * label is an int32 scalar tensor.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        dense_keys=['image_raw', 'label'],
        # Defaults are not specified since both keys are required.
        dense_types=[tf.string, tf.int64])

    # Convert from a scalar string tensor (whose single string has
    # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
    # [mnist.IMAGE_PIXELS].
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([mnist.IMAGE_PIXELS])

    # OPTIONAL: Could reshape into a 28x28 image and apply distortions
    # here.  Since we are not applying any distortions in this
    # example, and the next step expects the image to be flattened
    # into a vector, we don't bother.

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # Convert label from a scalar int64 tensor (the dtype it is parsed
    # with above) to an int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label


def inputs(train, batch_size, num_epochs):
    """Reads input data num_epochs times.

    Args:
      train: Selects between the training (True) and validation (False)
        data.
      batch_size: Number of examples per returned batch.
      num_epochs: Number of times to read the input data, or 0/None to
        train forever.

    Returns:
      A tuple (images, labels), where:
      * images is a float tensor with shape
        [batch_size, mnist.IMAGE_PIXELS] in the range [-0.5, 0.5].
      * labels is an int32 tensor with shape [batch_size] with the true
        label, a number in the range [0, mnist.NUM_CLASSES).
      Note that a tf.train.QueueRunner is added to the graph, which
      must be run using e.g. tf.train.start_queue_runners().
    """
    # Normalize 0 (and any other falsy value) to None, i.e. "no limit".
    if not num_epochs:
        num_epochs = None
    filename = os.path.join(FLAGS.train_dir,
                            TRAIN_FILE if train else VALIDATION_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)

        # Even when reading in multiple threads, share the filename
        # queue.
        image, label = read_and_decode(filename_queue)

        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in two threads to avoid being a bottleneck.
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=1000)

        return images, sparse_labels


def run_training():
    """Train MNIST for a number of steps."""

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Input images and labels.
        images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                                num_epochs=FLAGS.num_epochs)

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)

        # Add to the Graph the loss calculation.
        loss = mnist.loss(logits, labels)

        # Add to the Graph operations that train the model.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # The op for initializing the variables.
        init_op = tf.initialize_all_variables()

        # Create a session for running operations in the Graph.
        sess = tf.Session()
        try:
            # Initialize the variables (the trained variables and the
            # epoch counter).
            sess.run(init_op)

            # Start input enqueue threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                step = 0
                while not coord.should_stop():
                    start_time = time.time()

                    # Run one step of the model.  The return values are
                    # the activations from the `train_op` (which is
                    # discarded) and the `loss` op.  To inspect the values
                    # of your ops or variables, you may include them in
                    # the list passed to sess.run() and the value tensors
                    # will be returned in the tuple from the call.
                    _, loss_value = sess.run([train_op, loss])

                    duration = time.time() - start_time

                    # Print an overview fairly often.
                    if step % 100 == 0:
                        # Single parenthesized argument: prints the same
                        # text under both Python 2 and Python 3.
                        print('Step %d: loss = %.2f (%.3f sec)' % (
                            step, loss_value, duration))
                    step += 1
            except tf.errors.OutOfRangeError:
                # Raised once the input queue has served num_epochs epochs.
                print('Done training for %d epochs, %d steps.' % (
                    FLAGS.num_epochs, step))
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()

            # Wait for threads to finish.
            coord.join(threads)
        finally:
            # Release the session even if training or the join above
            # raised something other than OutOfRangeError.
            sess.close()


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()