aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-28 23:34:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-28 23:44:32 -0800
commit7061dcef3b9717476b2ec74c043f6c30a707376c (patch)
tree7fa4309c5b50f6ef904c1a5148ebe5b77f0c7e30 /tensorflow/examples/image_retraining
parentd0aee9e749e7266ac4f0d18f4d9f7fd9a7487eda (diff)
Improvements to image retraining example code:
1. Add an option to use the entire test/validation set for evaluating the model quality, rather than choosing a random subset. This makes the metrics much more stable across training iterations and across runs. 2. Add an option to output the list of failing examples in the test set. Change: 140444449
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py126
1 files changed, 83 insertions, 43 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 6b34e57b8f..9e2d01615d 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -85,14 +85,6 @@ from tensorflow.python.util import compat
FLAGS = None
-# Input and output file flags.
-
-# Details of the training configuration.
-
-# File-system cache locations.
-
-# Controls the distortions used during training.
-
# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
@@ -455,7 +447,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
Args:
sess: Current TensorFlow Session.
image_lists: Dictionary of training images for each label.
- how_many: The number of bottleneck values to return.
+ how_many: If positive, a random sample of this size will be chosen.
+ If negative, all bottlenecks will be retrieved.
category: Name string of which set to pull from - training, testing, or
validation.
bottleneck_dir: Folder string holding cached files of bottleneck values.
@@ -465,24 +458,47 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
bottleneck_tensor: The bottleneck output layer of the CNN graph.
Returns:
- List of bottleneck arrays and their corresponding ground truths.
+ List of bottleneck arrays, their corresponding ground truths, and the
+ relevant filenames.
"""
class_count = len(image_lists.keys())
bottlenecks = []
ground_truths = []
- for unused_i in range(how_many):
- label_index = random.randrange(class_count)
- label_name = list(image_lists.keys())[label_index]
- image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
- bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
- image_index, image_dir, category,
- bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor)
- ground_truth = np.zeros(class_count, dtype=np.float32)
- ground_truth[label_index] = 1.0
- bottlenecks.append(bottleneck)
- ground_truths.append(ground_truth)
- return bottlenecks, ground_truths
+ filenames = []
+ if how_many >= 0:
+ # Retrieve a random sample of bottlenecks.
+ for unused_i in range(how_many):
+ label_index = random.randrange(class_count)
+ label_name = list(image_lists.keys())[label_index]
+ image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
+ image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category)
+ bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
+ image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+ ground_truth = np.zeros(class_count, dtype=np.float32)
+ ground_truth[label_index] = 1.0
+ bottlenecks.append(bottleneck)
+ ground_truths.append(ground_truth)
+ filenames.append(image_name)
+ else:
+ # Retrieve all bottlenecks.
+ for label_index, label_name in enumerate(image_lists.keys()):
+ for image_index, image_name in enumerate(
+ image_lists[label_name][category]):
+ image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category)
+ bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
+ image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+ ground_truth = np.zeros(class_count, dtype=np.float32)
+ ground_truth[label_index] = 1.0
+ bottlenecks.append(bottleneck)
+ ground_truths.append(ground_truth)
+ filenames.append(image_name)
+ return bottlenecks, ground_truths, filenames
def get_random_distorted_bottlenecks(
@@ -729,16 +745,17 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
into.
Returns:
- Nothing.
+ Tuple of (evaluation step, prediction).
"""
with tf.name_scope('accuracy'):
with tf.name_scope('correct_prediction'):
- correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \
- tf.argmax(ground_truth_tensor, 1))
+ prediction = tf.argmax(result_tensor, 1)
+ correct_prediction = tf.equal(
+ prediction, tf.argmax(ground_truth_tensor, 1))
with tf.name_scope('accuracy'):
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
- return evaluation_step
+ return evaluation_step, prediction
def main(_):
@@ -788,7 +805,8 @@ def main(_):
bottleneck_tensor)
# Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
+ evaluation_step, prediction = add_evaluation_step(
+ final_tensor, ground_truth_input)
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
merged = tf.summary.merge_all()
@@ -810,7 +828,7 @@ def main(_):
FLAGS.image_dir, distorted_jpeg_data_tensor,
distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
else:
- train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
+ train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks(
sess, image_lists, FLAGS.train_batch_size, 'training',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor)
@@ -832,7 +850,7 @@ def main(_):
train_accuracy * 100))
print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
cross_entropy_value))
- validation_bottlenecks, validation_ground_truth = (
+ validation_bottlenecks, validation_ground_truth, _ = (
get_random_cached_bottlenecks(
sess, image_lists, FLAGS.validation_batch_size, 'validation',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
@@ -844,20 +862,29 @@ def main(_):
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
validation_writer.add_summary(validation_summary, i)
- print('%s: Step %d: Validation accuracy = %.1f%%' %
- (datetime.now(), i, validation_accuracy * 100))
+ print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
+ (datetime.now(), i, validation_accuracy * 100,
+ len(validation_bottlenecks)))
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
- test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks(
- sess, image_lists, FLAGS.test_batch_size, 'testing',
- FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor)
- test_accuracy = sess.run(
- evaluation_step,
+ test_bottlenecks, test_ground_truth, test_filenames = (
+ get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
+ 'testing', FLAGS.bottleneck_dir,
+ FLAGS.image_dir, jpeg_data_tensor,
+ bottleneck_tensor))
+ test_accuracy, predictions = sess.run(
+ [evaluation_step, prediction],
feed_dict={bottleneck_input: test_bottlenecks,
ground_truth_input: test_ground_truth})
- print('Final test accuracy = %.1f%%' % (test_accuracy * 100))
+ print('Final test accuracy = %.1f%% (N=%d)' % (
+ test_accuracy * 100, len(test_bottlenecks)))
+
+ if FLAGS.print_misclassified_test_images:
+ print('=== MISCLASSIFIED TEST IMAGES ===')
+ for i, test_filename in enumerate(test_filenames):
+ if predictions[i] != test_ground_truth[i].argmax():
+ print('%70s %s' % (test_filename, image_lists.keys()[predictions[i]]))
# Write out the trained graph and labels with the weights stored as constants.
output_graph_def = graph_util.convert_variables_to_constants(
@@ -933,10 +960,12 @@ if __name__ == '__main__':
parser.add_argument(
'--test_batch_size',
type=int,
- default=500,
+ default=-1,
help="""\
- How many images to test on at a time. This test set is only used
- infrequently to verify the overall accuracy of the model.\
+ How many images to test on. This test set is only used once, to evaluate
+ the final accuracy of the model after training completes.
+ A value of -1 causes the entire test set to be used, which leads to more
+ stable results across runs.\
"""
)
parser.add_argument(
@@ -946,10 +975,21 @@ if __name__ == '__main__':
help="""\
How many images to use in an evaluation batch. This validation set is
used much more often than the test set, and is an early indicator of how
- accurate the model is during training.\
+ accurate the model is during training.
+ A value of -1 causes the entire validation set to be used, which leads to
+ more stable results across training iterations, but may be slower on large
+ training sets.\
"""
)
parser.add_argument(
+ '--print_misclassified_test_images',
+ default=False,
+ help="""\
+ Whether to print out a list of all misclassified test images.\
+ """,
+ action='store_true'
+ )
+ parser.add_argument(
'--model_dir',
type=str,
default='/tmp/imagenet',