aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image/cifar10/cifar10_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_train.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10_train.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
new file mode 100644
index 0000000000..bcb6eeae58
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -0,0 +1,119 @@
+"""A binary to train CIFAR-10 using a single GPU.
+
+Accuracy:
+cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
+data) as judged by cifar10_eval.py.
+
+Speed: With batch_size 128.
+
+System | Step Time (sec/batch) | Accuracy
+------------------------------------------------------------------
+1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
+1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from datetime import datetime
+import os.path
+import time
+
+import tensorflow.python.platform
+from tensorflow.python.platform import gfile
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
+ """Directory where to write event logs """
+ """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000,
+ """Number of batches to run.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+ """Whether to log device placement.""")
+
+
+def train():
+ """Train CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default():
+ global_step = tf.Variable(0, trainable=False)
+
+ # Get images and labels for CIFAR-10.
+ images, labels = cifar10.distorted_inputs()
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate loss.
+ loss = cifar10.loss(logits, labels)
+
+ # Build a Graph that trains the model with one batch of examples and
+ # updates the model parameters.
+ train_op = cifar10.train(loss, global_step)
+
+ # Create a saver.
+ saver = tf.train.Saver(tf.all_variables())
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Build an initialization operation to run below.
+ init = tf.initialize_all_variables()
+
+ # Start running operations on the Graph.
+ sess = tf.Session(config=tf.ConfigProto(
+ log_device_placement=FLAGS.log_device_placement))
+ sess.run(init)
+
+ # Start the queue runners.
+ tf.train.start_queue_runners(sess=sess)
+
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ for step in xrange(FLAGS.max_steps):
+ start_time = time.time()
+ _, loss_value = sess.run([train_op, loss])
+ duration = time.time() - start_time
+
+ assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+ if step % 10 == 0:
+ num_examples_per_step = FLAGS.batch_size
+ examples_per_sec = num_examples_per_step / float(duration)
+ sec_per_batch = float(duration)
+
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print (format_str % (datetime.now(), step, loss_value,
+ examples_per_sec, sec_per_batch))
+
+ if step % 100 == 0:
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+
+ # Save the model checkpoint periodically.
+ if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+ checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
+ saver.save(sess, checkpoint_path, global_step=step)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if gfile.Exists(FLAGS.train_dir):
+ gfile.DeleteRecursively(FLAGS.train_dir)
+ gfile.MakeDirs(FLAGS.train_dir)
+ train()
+
+
+if __name__ == '__main__':
+ tf.app.run()