aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py')
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
new file mode 100644
index 0000000000..b2436cd2ab
--- /dev/null
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
@@ -0,0 +1,134 @@
+"""Trains the MNIST network using preloaded data in a constant.
+
+Command to run this py_binary target:
+
+bazel run -c opt \
+ <...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded
+"""
+import os.path
+import time
+
+import tensorflow.python.platform
+import numpy
+import tensorflow as tf
+
+from tensorflow.g3doc.tutorials.mnist import input_data
+from tensorflow.g3doc.tutorials.mnist import mnist
+
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
+flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+ 'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
+ 'for unit testing.')
+
+
+def run_training():
+ """Train MNIST for a number of epochs."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ with tf.name_scope('input'):
+ # Input data
+ input_images = tf.constant(data_sets.train.images)
+ input_labels = tf.constant(data_sets.train.labels)
+
+ image, label = tf.train.slice_input_producer(
+ [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+ label = tf.cast(label, tf.int32)
+ images, labels = tf.train.batch(
+ [image, label], batch_size=FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create the op for initializing variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ sess.run(init_op)
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # And then after everything is built, start the training loop.
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print 'Step %d: loss = %.2f (%.3f sec)' % (step,
+ loss_value,
+ duration)
+ # Update the events file.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ step += 1
+
+ # Save a checkpoint periodically.
+ if (step + 1) % 1000 == 0:
+ print 'Saving'
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print 'Saving'
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ print 'Done training for %d epochs, %d steps.' % (
+ FLAGS.num_epochs, step)
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
+
+
+def main(_):
+ run_training()
+
+
+if __name__ == '__main__':
+ tf.app.run()