aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2017-04-17 16:27:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-17 17:54:59 -0700
commita8b175a7d79032308f147351a85039bc3a29694a (patch)
tree5c07d2825dc760f390f00fdfd25dae0086205cb1 /tensorflow/examples/image_retraining
parent6656c28260ae93a24d032b44cb96aca87c8d350a (diff)
Improve graph and session management for retrain.py
Also fixed existing lint errors. Change: 153410743
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py284
1 files changed, 162 insertions, 122 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index a3a4ba310e..5f4b6bed48 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Simple transfer learning with an Inception v3 architecture model which
-displays summaries in TensorBoard.
+r"""Simple transfer learning with an Inception v3 architecture model.
+
+With support for TensorBoard.
This example shows how to take a Inception v3 architecture model trained on
ImageNet images, and train a new top layer that can recognize other classes of
@@ -39,9 +40,20 @@ The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this:
+
+```bash
bazel build tensorflow/examples/image_retraining:retrain && \
bazel-bin/tensorflow/examples/image_retraining/retrain \
---image_dir ~/flower_photos
+ --image_dir ~/flower_photos
+```
+
+Or, if you have a pip installation of tensorflow, `retrain.py` can be run
+without bazel:
+
+```bash
+python tensorflow/examples/image_retraining/retrain.py \
+ --image_dir ~/flower_photos
+```
You can replace the image_dir argument with any folder containing subfolders of
images. The label for each image is taken from the name of the subfolder it's
@@ -244,7 +256,7 @@ def create_inception_graph():
Graph holding the trained Inception network, and various tensors we'll be
manipulating.
"""
- with tf.Session() as sess:
+ with tf.Graph().as_default() as graph:
model_filename = os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb')
with gfile.FastGFile(model_filename, 'rb') as f:
@@ -254,7 +266,7 @@ def create_inception_graph():
tf.import_graph_def(graph_def, name='', return_elements=[
BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
RESIZED_INPUT_TENSOR_NAME]))
- return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
+ return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
def run_bottleneck_on_image(sess, image_data, image_data_tensor,
@@ -315,7 +327,7 @@ def ensure_dir_exists(dir_name):
os.makedirs(dir_name)
-def write_list_of_floats_to_file(list_of_floats , file_path):
+def write_list_of_floats_to_file(list_of_floats, file_path):
"""Writes a given list of floats to a binary file.
Args:
@@ -346,18 +358,25 @@ def read_list_of_floats_from_file(file_path):
bottleneck_path_2_bottleneck_values = {}
+
def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
- image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor):
+ image_dir, category, sess, jpeg_data_tensor,
+ bottleneck_tensor):
+ """Create a single bottleneck file."""
print('Creating bottleneck at ' + bottleneck_path)
- image_path = get_image_path(image_lists, label_name, index, image_dir, category)
+ image_path = get_image_path(image_lists, label_name, index,
+ image_dir, category)
if not gfile.Exists(image_path):
tf.logging.fatal('File does not exist %s', image_path)
image_data = gfile.FastGFile(image_path, 'rb').read()
- bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
+ bottleneck_values = run_bottleneck_on_image(sess, image_data,
+ jpeg_data_tensor,
+ bottleneck_tensor)
bottleneck_string = ','.join(str(x) for x in bottleneck_values)
with open(bottleneck_path, 'w') as bottleneck_file:
bottleneck_file.write(bottleneck_string)
+
def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
category, bottleneck_dir, jpeg_data_tensor,
bottleneck_tensor):
@@ -387,25 +406,32 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
sub_dir = label_lists['dir']
sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
ensure_dir_exists(sub_dir_path)
- bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
+ bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
+ bottleneck_dir, category)
if not os.path.exists(bottleneck_path):
- create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
+ create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
+ image_dir, category, sess, jpeg_data_tensor,
+ bottleneck_tensor)
with open(bottleneck_path, 'r') as bottleneck_file:
bottleneck_string = bottleneck_file.read()
did_hit_error = False
try:
bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
- except:
- print("Invalid float found, recreating bottleneck")
+ except ValueError:
+ print('Invalid float found, recreating bottleneck')
did_hit_error = True
if did_hit_error:
- create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
+ create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
+ image_dir, category, sess, jpeg_data_tensor,
+ bottleneck_tensor)
with open(bottleneck_path, 'r') as bottleneck_file:
bottleneck_string = bottleneck_file.read()
- # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
+ # Allow exceptions to propagate here, since they shouldn't happen after a
+ # fresh creation
bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
return bottleneck_values
+
def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
jpeg_data_tensor, bottleneck_tensor):
"""Ensures all the training, testing, and validation bottlenecks are cached.
@@ -718,7 +744,11 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
layer_name = 'final_training_ops'
with tf.name_scope(layer_name):
with tf.name_scope('weights'):
- layer_weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), name='final_weights')
+ initial_value = tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count],
+ stddev=0.001)
+
+ layer_weights = tf.Variable(initial_value, name='final_weights')
+
variable_summaries(layer_weights)
with tf.name_scope('biases'):
layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
@@ -738,8 +768,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
tf.summary.scalar('cross_entropy', cross_entropy_mean)
with tf.name_scope('train'):
- train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
- cross_entropy_mean)
+ optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
+ train_step = optimizer.minimize(cross_entropy_mean)
return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
final_tensor)
@@ -794,115 +824,125 @@ def main(_):
do_distort_images = should_distort_images(
FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
FLAGS.random_brightness)
- sess = tf.Session()
- if do_distort_images:
- # We will be applying distortions, so setup the operations we'll need.
- distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions(
- FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
- FLAGS.random_brightness)
- else:
- # We'll make sure we've calculated the 'bottleneck' image summaries and
- # cached them on disk.
- cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir,
- jpeg_data_tensor, bottleneck_tensor)
-
- # Add the new layer that we'll be training.
- (train_step, cross_entropy, bottleneck_input, ground_truth_input,
- final_tensor) = add_final_training_ops(len(image_lists.keys()),
- FLAGS.final_tensor_name,
- bottleneck_tensor)
-
- # Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step, prediction = add_evaluation_step(
- final_tensor, ground_truth_input)
-
- # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
- merged = tf.summary.merge_all()
- train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
- sess.graph)
- validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
-
- # Set up all our weights to their initial default values.
- init = tf.global_variables_initializer()
- sess.run(init)
-
- # Run the training for as many cycles as requested on the command line.
- for i in range(FLAGS.how_many_training_steps):
- # Get a batch of input bottleneck values, either calculated fresh every time
- # with distortions applied, or from the cache stored on disk.
+ with tf.Session(graph=graph) as sess:
+
if do_distort_images:
- train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks(
- sess, image_lists, FLAGS.train_batch_size, 'training',
- FLAGS.image_dir, distorted_jpeg_data_tensor,
- distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
+ # We will be applying distortions, so setup the operations we'll need.
+ (distorted_jpeg_data_tensor,
+ distorted_image_tensor) = add_input_distortions(
+ FLAGS.flip_left_right, FLAGS.random_crop,
+ FLAGS.random_scale, FLAGS.random_brightness)
else:
- train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks(
- sess, image_lists, FLAGS.train_batch_size, 'training',
- FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor)
- # Feed the bottlenecks and ground truth into the graph, and run a training
- # step. Capture training summaries for TensorBoard with the `merged` op.
- train_summary, _ = sess.run([merged, train_step],
- feed_dict={bottleneck_input: train_bottlenecks,
- ground_truth_input: train_ground_truth})
- train_writer.add_summary(train_summary, i)
-
- # Every so often, print out how well the graph is training.
- is_last_step = (i + 1 == FLAGS.how_many_training_steps)
- if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
- train_accuracy, cross_entropy_value = sess.run(
- [evaluation_step, cross_entropy],
+ # We'll make sure we've calculated the 'bottleneck' image summaries and
+ # cached them on disk.
+ cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
+ FLAGS.bottleneck_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+
+ # Add the new layer that we'll be training.
+ (train_step, cross_entropy, bottleneck_input, ground_truth_input,
+ final_tensor) = add_final_training_ops(len(image_lists.keys()),
+ FLAGS.final_tensor_name,
+ bottleneck_tensor)
+
+ # Create the operations we need to evaluate the accuracy of our new layer.
+ evaluation_step, prediction = add_evaluation_step(
+ final_tensor, ground_truth_input)
+
+ # Merge all the summaries and write them out to the summaries_dir
+ merged = tf.summary.merge_all()
+ train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
+ sess.graph)
+
+ validation_writer = tf.summary.FileWriter(
+ FLAGS.summaries_dir + '/validation')
+
+ # Set up all our weights to their initial default values.
+ init = tf.global_variables_initializer()
+ sess.run(init)
+
+ # Run the training for as many cycles as requested on the command line.
+ for i in range(FLAGS.how_many_training_steps):
+ # Get a batch of input bottleneck values, either calculated fresh every
+ # time with distortions applied, or from the cache stored on disk.
+ if do_distort_images:
+ (train_bottlenecks,
+ train_ground_truth) = get_random_distorted_bottlenecks(
+ sess, image_lists, FLAGS.train_batch_size, 'training',
+ FLAGS.image_dir, distorted_jpeg_data_tensor,
+ distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
+ else:
+ (train_bottlenecks,
+ train_ground_truth, _) = get_random_cached_bottlenecks(
+ sess, image_lists, FLAGS.train_batch_size, 'training',
+ FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+ # Feed the bottlenecks and ground truth into the graph, and run a training
+ # step. Capture training summaries for TensorBoard with the `merged` op.
+
+ train_summary, _ = sess.run(
+ [merged, train_step],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
- print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
- train_accuracy * 100))
- print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
- cross_entropy_value))
- validation_bottlenecks, validation_ground_truth, _ = (
- get_random_cached_bottlenecks(
- sess, image_lists, FLAGS.validation_batch_size, 'validation',
- FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor))
- # Run a validation step and capture training summaries for TensorBoard
- # with the `merged` op.
- validation_summary, validation_accuracy = sess.run(
- [merged, evaluation_step],
- feed_dict={bottleneck_input: validation_bottlenecks,
- ground_truth_input: validation_ground_truth})
- validation_writer.add_summary(validation_summary, i)
- print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
- (datetime.now(), i, validation_accuracy * 100,
- len(validation_bottlenecks)))
-
- # We've completed all our training, so run a final test evaluation on
- # some new images we haven't used before.
- test_bottlenecks, test_ground_truth, test_filenames = (
- get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
- 'testing', FLAGS.bottleneck_dir,
- FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor))
- test_accuracy, predictions = sess.run(
- [evaluation_step, prediction],
- feed_dict={bottleneck_input: test_bottlenecks,
- ground_truth_input: test_ground_truth})
- print('Final test accuracy = %.1f%% (N=%d)' % (
- test_accuracy * 100, len(test_bottlenecks)))
-
- if FLAGS.print_misclassified_test_images:
- print('=== MISCLASSIFIED TEST IMAGES ===')
- for i, test_filename in enumerate(test_filenames):
- if predictions[i] != test_ground_truth[i].argmax():
- print('%70s %s' % (test_filename,
- list(image_lists.keys())[predictions[i]]))
-
- # Write out the trained graph and labels with the weights stored as constants.
- output_graph_def = graph_util.convert_variables_to_constants(
- sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
- with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
- f.write(output_graph_def.SerializeToString())
- with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
- f.write('\n'.join(image_lists.keys()) + '\n')
+ train_writer.add_summary(train_summary, i)
+
+ # Every so often, print out how well the graph is training.
+ is_last_step = (i + 1 == FLAGS.how_many_training_steps)
+ if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
+ train_accuracy, cross_entropy_value = sess.run(
+ [evaluation_step, cross_entropy],
+ feed_dict={bottleneck_input: train_bottlenecks,
+ ground_truth_input: train_ground_truth})
+ print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
+ train_accuracy * 100))
+ print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
+ cross_entropy_value))
+ validation_bottlenecks, validation_ground_truth, _ = (
+ get_random_cached_bottlenecks(
+ sess, image_lists, FLAGS.validation_batch_size, 'validation',
+ FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
+ bottleneck_tensor))
+ # Run a validation step and capture training summaries for TensorBoard
+ # with the `merged` op.
+ validation_summary, validation_accuracy = sess.run(
+ [merged, evaluation_step],
+ feed_dict={bottleneck_input: validation_bottlenecks,
+ ground_truth_input: validation_ground_truth})
+ validation_writer.add_summary(validation_summary, i)
+ print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
+ (datetime.now(), i, validation_accuracy * 100,
+ len(validation_bottlenecks)))
+
+ # We've completed all our training, so run a final test evaluation on
+ # some new images we haven't used before.
+ test_bottlenecks, test_ground_truth, test_filenames = (
+ get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
+ 'testing', FLAGS.bottleneck_dir,
+ FLAGS.image_dir, jpeg_data_tensor,
+ bottleneck_tensor))
+ test_accuracy, predictions = sess.run(
+ [evaluation_step, prediction],
+ feed_dict={bottleneck_input: test_bottlenecks,
+ ground_truth_input: test_ground_truth})
+ print('Final test accuracy = %.1f%% (N=%d)' % (
+ test_accuracy * 100, len(test_bottlenecks)))
+
+ if FLAGS.print_misclassified_test_images:
+ print('=== MISCLASSIFIED TEST IMAGES ===')
+ for i, test_filename in enumerate(test_filenames):
+ if predictions[i] != test_ground_truth[i].argmax():
+ print('%70s %s' % (test_filename,
+ list(image_lists.keys())[predictions[i]]))
+
+ # Write out the trained graph and labels with the weights stored as
+ # constants.
+ output_graph_def = graph_util.convert_variables_to_constants(
+ sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+ with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
+ f.write(output_graph_def.SerializeToString())
+ with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
+ f.write('\n'.join(image_lists.keys()) + '\n')
if __name__ == '__main__':