From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- .../reading_data/fully_connected_preloaded.py | 134 +++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py (limited to 'tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py') diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py new file mode 100644 index 0000000000..b2436cd2ab --- /dev/null +++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py @@ -0,0 +1,134 @@ +"""Trains the MNIST network using preloaded data in a constant. + +Command to run this py_binary target: + +bazel run -c opt \ + <...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded +""" +import os.path +import time + +import tensorflow.python.platform +import numpy +import tensorflow as tf + +from tensorflow.g3doc.tutorials.mnist import input_data +from tensorflow.g3doc.tutorials.mnist import mnist + + +# Basic model parameters as external flags. +flags = tf.app.flags +FLAGS = flags.FLAGS +flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') +flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.') +flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.') +flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.') +flags.DEFINE_integer('batch_size', 100, 'Batch size. ' + 'Must divide evenly into the dataset sizes.') +flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.') +flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' + 'for unit testing.') + + +def run_training(): + """Train MNIST for a number of epochs.""" + # Get the sets of images and labels for training, validation, and + # test on MNIST. + data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data) + + # Tell TensorFlow that the model will be built into the default Graph. + with tf.Graph().as_default(): + with tf.name_scope('input'): + # Input data + input_images = tf.constant(data_sets.train.images) + input_labels = tf.constant(data_sets.train.labels) + + image, label = tf.train.slice_input_producer( + [input_images, input_labels], num_epochs=FLAGS.num_epochs) + label = tf.cast(label, tf.int32) + images, labels = tf.train.batch( + [image, label], batch_size=FLAGS.batch_size) + + # Build a Graph that computes predictions from the inference model. + logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) + + # Add to the Graph the Ops for loss calculation. + loss = mnist.loss(logits, labels) + + # Add to the Graph the Ops that calculate and apply gradients. + train_op = mnist.training(loss, FLAGS.learning_rate) + + # Add the Op to compare the logits to the labels during evaluation. + eval_correct = mnist.evaluation(logits, labels) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.merge_all_summaries() + + # Create a saver for writing training checkpoints. + saver = tf.train.Saver() + + # Create the op for initializing variables. + init_op = tf.initialize_all_variables() + + # Create a session for running Ops on the Graph. + sess = tf.Session() + + # Run the Op to initialize the variables. + sess.run(init_op) + + # Instantiate a SummaryWriter to output summaries and the Graph. + summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, + graph_def=sess.graph_def) + + # Start input enqueue threads. + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(sess=sess, coord=coord) + + # And then after everything is built, start the training loop. + try: + step = 0 + while not coord.should_stop(): + start_time = time.time() + + # Run one step of the model. + _, loss_value = sess.run([train_op, loss]) + + duration = time.time() - start_time + + # Write the summaries and print an overview fairly often. + if step % 100 == 0: + # Print status to stdout. + print 'Step %d: loss = %.2f (%.3f sec)' % (step, + loss_value, + duration) + # Update the events file. + summary_str = sess.run(summary_op) + summary_writer.add_summary(summary_str, step) + step += 1 + + # Save a checkpoint periodically. + if (step + 1) % 1000 == 0: + print 'Saving' + saver.save(sess, FLAGS.train_dir, global_step=step) + + step += 1 + except tf.errors.OutOfRangeError: + print 'Saving' + saver.save(sess, FLAGS.train_dir, global_step=step) + print 'Done training for %d epochs, %d steps.' % ( + FLAGS.num_epochs, step) + finally: + # When done, ask the threads to stop. + coord.request_stop() + + # Wait for threads to finish. + coord.join(threads) + sess.close() + + +def main(_): + run_training() + + +if __name__ == '__main__': + tf.app.run() -- cgit v1.2.3