aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image/cifar10/cifar10_eval.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/models/image/cifar10/cifar10_eval.py')
-rw-r--r--tensorflow/models/image/cifar10/cifar10_eval.py148
1 files changed, 148 insertions, 0 deletions
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
new file mode 100644
index 0000000000..73c224191d
--- /dev/null
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -0,0 +1,148 @@
+"""Evaluation for CIFAR-10.
+
+Accuracy:
+cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
+of data) as judged by cifar10_eval.py.
+
+Speed:
+On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
+in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
+accuracy after 100K steps in 8 hours of training time.
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from datetime import datetime
+import math
+import time
+
+import tensorflow.python.platform
+from tensorflow.python.platform import gfile
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
+ """Directory where to write event logs.""")
+tf.app.flags.DEFINE_string('eval_data', 'test',
+ """Either 'test' or 'train_eval'.""")
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
+ """Directory where to read model checkpoints.""")
+tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+ """How often to run the eval.""")
+tf.app.flags.DEFINE_integer('num_examples', 10000,
+ """Number of examples to run.""")
+tf.app.flags.DEFINE_boolean('run_once', False,
+ """Whether to run eval only once.""")
+
+
+def eval_once(saver, summary_writer, top_k_op, summary_op):
+ """Run Eval once.
+
+ Args:
+ saver: Saver.
+ summary_writer: Summary writer.
+ top_k_op: Top K op.
+ summary_op: Summary op.
+ """
+ with tf.Session() as sess:
+ ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ # Restores from checkpoint
+ saver.restore(sess, ckpt.model_checkpoint_path)
+ # Assuming model_checkpoint_path looks something like:
+ # /my-favorite-path/cifar10_train/model.ckpt-0,
+ # extract global_step from it.
+ global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
+ else:
+ print 'No checkpoint file found'
+ return
+
+ # Start the queue runners.
+ coord = tf.train.Coordinator()
+ try:
+ threads = []
+ for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
+ start=True))
+
+ num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
+ true_count = 0 # Counts the number of correct predictions.
+ total_sample_count = num_iter * FLAGS.batch_size
+ step = 0
+ while step < num_iter and not coord.should_stop():
+ predictions = sess.run([top_k_op])
+ true_count += np.sum(predictions)
+ step += 1
+
+ # Compute precision @ 1.
+ precision = float(true_count) / float(total_sample_count)
+ print '%s: precision @ 1 = %.3f' % (datetime.now(), precision)
+
+ summary = tf.Summary()
+ summary.ParseFromString(sess.run(summary_op))
+ summary.value.add(tag='Precision @ 1', simple_value=precision)
+ summary_writer.add_summary(summary, global_step)
+ except Exception, e: # pylint: disable=broad-except
+ coord.request_stop(e)
+
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=10)
+
+
+def evaluate():
+ """Eval CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default():
+ # Get images and labels for CIFAR-10.
+ eval_data = FLAGS.eval_data == 'test'
+ images, labels = cifar10.inputs(eval_data=eval_data)
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate predictions.
+ top_k_op = tf.nn.in_top_k(logits, labels, 1)
+
+ # Restore the moving average version of the learned variables for eval.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ cifar10.MOVING_AVERAGE_DECAY)
+ variables_to_restore = {}
+ for v in tf.all_variables():
+ if v in tf.trainable_variables():
+ restore_name = variable_averages.average_name(v)
+ else:
+ restore_name = v.op.name
+ variables_to_restore[restore_name] = v
+ saver = tf.train.Saver(variables_to_restore)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ graph_def = tf.get_default_graph().as_graph_def()
+ summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
+ graph_def=graph_def)
+
+ while True:
+ eval_once(saver, summary_writer, top_k_op, summary_op)
+ if FLAGS.run_once:
+ break
+ time.sleep(FLAGS.eval_interval_secs)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if gfile.Exists(FLAGS.eval_dir):
+ gfile.DeleteRecursively(FLAGS.eval_dir)
+ gfile.MakeDirs(FLAGS.eval_dir)
+ evaluate()
+
+
+if __name__ == '__main__':
+ tf.app.run()