diff options
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_train.py')
-rw-r--r-- | tensorflow/models/image/cifar10/cifar10_train.py | 119 |
1 files changed, 119 insertions, 0 deletions
# Reconstructed from the "new file" diff hunk (index 0000000000..bcb6eeae58)
# for tensorflow/models/image/cifar10/cifar10_train.py.
"""A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from datetime import datetime
import os.path
import time

import tensorflow.python.platform
from tensorflow.python.platform import gfile

import numpy as np

import tensorflow as tf

from tensorflow.models.image.cifar10 import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")


def train():
  """Train CIFAR-10 for FLAGS.max_steps steps.

  Builds the input, inference, loss and training ops from the `cifar10`
  module in a fresh graph and runs the training op in a session.  Every 10
  steps a progress line is printed, every 100 steps the merged summaries are
  written to FLAGS.train_dir, and every 1000 steps (and on the final step) a
  checkpoint is saved as FLAGS.train_dir/model.ckpt-<step>.

  Raises:
    AssertionError: if the loss value becomes NaN.
  """
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    # NOTE(review): the session is deliberately left open, as in the original;
    # the queue-runner threads started below keep using it until the process
    # exits.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    try:
      for step in xrange(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss])
        duration = time.time() - start_time

        if np.isnan(loss_value):
          # Raised explicitly instead of via `assert` so the divergence check
          # is not stripped under `python -O`.  The exception type and
          # message are the same as the assert produced.
          raise AssertionError('Model diverged with loss = NaN')

        if step % 10 == 0:
          # batch_size is not defined in this file; presumably the flag is
          # registered by the imported cifar10 module -- TODO confirm.
          num_examples_per_step = FLAGS.batch_size
          examples_per_sec = num_examples_per_step / float(duration)
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), step, loss_value,
                              examples_per_sec, sec_per_batch))

        if step % 100 == 0:
          summary_str = sess.run(summary_op)
          summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
          saver.save(sess, checkpoint_path, global_step=step)
    finally:
      # Close the writer so the summaries already added are flushed to the
      # event file whether the loop finished or raised.
      summary_writer.close()


def main(argv=None):  # pylint: disable=unused-argument
  """Download the data set if needed, reset FLAGS.train_dir and train.

  An existing FLAGS.train_dir is deleted recursively and recreated empty
  before training starts, so earlier event logs and checkpoints in that
  directory are discarded.
  """
  cifar10.maybe_download_and_extract()
  if gfile.Exists(FLAGS.train_dir):
    gfile.DeleteRecursively(FLAGS.train_dir)
  gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()