diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-06-24 16:50:40 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-24 18:05:02 -0700 |
commit | 67324b1e3af826c4c491802f4022a5f5be9f6670 (patch) | |
tree | 018e21feade905a14d8beb2c8e7ebdfd905bbf51 /tensorflow/examples/image_retraining | |
parent | e30936c026655f1b2f4f45997da32c257d18b076 (diff) |
Merge changes from github.
Change: 125835079
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 118 |
1 file changed, 91 insertions, 27 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 9b01376577..6a3024d5bc 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Simple transfer learning with an Inception v3 architecture model. +"""Simple transfer learning with an Inception v3 architecture model which +displays summaries in TensorBoard. This example shows how to take a Inception v3 architecture model trained on ImageNet images, and train a new top layer that can recognize other classes of @@ -49,6 +50,15 @@ in. This produces a new model file that can be loaded and run by any TensorFlow program, for example the label_image sample code. + +To use with TensorBoard: + +By default, this script will log summaries to /tmp/retrain_logs directory + +Visualize the summaries with this command: + +tensorboard --logdir /tmp/retrain_logs + """ from __future__ import absolute_import from __future__ import division @@ -81,6 +91,8 @@ tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb', """Where to save the trained graph.""") tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt', """Where to save the trained graph's labels.""") +tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs', + """Where to save summary logs for TensorBoard.""") # Details of the training configuration. 
tf.app.flags.DEFINE_integer('how_many_training_steps', 4000, @@ -650,6 +662,19 @@ def add_input_distortions(flip_left_right, random_crop, random_scale, return jpeg_data, distort_result +def variable_summaries(var, name): + """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" + with tf.name_scope('summaries'): + mean = tf.reduce_mean(var) + tf.scalar_summary('mean/' + name, mean) + with tf.name_scope('stddev'): + stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) + tf.scalar_summary('stddev/' + name, stddev) + tf.scalar_summary('max/' + name, tf.reduce_max(var)) + tf.scalar_summary('min/' + name, tf.reduce_min(var)) + tf.histogram_summary(name, var) + + def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor): """Adds a new softmax and fully-connected layer for training. @@ -670,24 +695,43 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor): The tensors for the training and cross entropy results, and tensors for the bottleneck input and ground truth input. 
""" - bottleneck_input = tf.placeholder_with_default( - bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE], - name='BottleneckInputPlaceholder') - layer_weights = tf.Variable( - tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), - name='final_weights') - layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') - logits = tf.matmul(bottleneck_input, layer_weights, - name='final_matmul') + layer_biases + with tf.name_scope('input'): + bottleneck_input = tf.placeholder_with_default( + bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE], + name='BottleneckInputPlaceholder') + + ground_truth_input = tf.placeholder(tf.float32, + [None, class_count], + name='GroundTruthInput') + + # Organizing the following ops as `final_training_ops` so they're easier + # to see in TensorBoard + layer_name = 'final_training_ops' + with tf.name_scope(layer_name): + with tf.name_scope('weights'): + layer_weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), name='final_weights') + variable_summaries(layer_weights, layer_name + '/weights') + with tf.name_scope('biases'): + layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') + variable_summaries(layer_biases, layer_name + '/biases') + with tf.name_scope('Wx_plus_b'): + logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases + tf.histogram_summary(layer_name + '/pre_activations', logits) + final_tensor = tf.nn.softmax(logits, name=final_tensor_name) - ground_truth_input = tf.placeholder(tf.float32, - [None, class_count], - name='GroundTruthInput') - cross_entropy = tf.nn.softmax_cross_entropy_with_logits( + tf.histogram_summary(final_tensor_name + '/activations', final_tensor) + + with tf.name_scope('cross_entropy'): + cross_entropy = tf.nn.softmax_cross_entropy_with_logits( logits, ground_truth_input) - cross_entropy_mean = tf.reduce_mean(cross_entropy) - train_step = 
tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize( - cross_entropy_mean) + with tf.name_scope('total'): + cross_entropy_mean = tf.reduce_mean(cross_entropy) + tf.scalar_summary('cross entropy', cross_entropy_mean) + + with tf.name_scope('train'): + train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize( + cross_entropy_mean) + return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, final_tensor) @@ -703,13 +747,22 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): Returns: Nothing. """ - correct_prediction = tf.equal( - tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1)) - evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float')) + with tf.name_scope('accuracy'): + with tf.name_scope('correct_prediction'): + correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \ + tf.argmax(ground_truth_tensor, 1)) + with tf.name_scope('accuracy'): + evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + tf.scalar_summary('accuracy', evaluation_step) return evaluation_step def main(_): + # Setup the directory we'll write summaries to for TensorBoard + if tf.gfile.Exists(FLAGS.summaries_dir): + tf.gfile.DeleteRecursively(FLAGS.summaries_dir) + tf.gfile.MakeDirs(FLAGS.summaries_dir) + # Set up the pre-trained graph. maybe_download_and_extract() graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = ( @@ -750,13 +803,19 @@ def main(_): FLAGS.final_tensor_name, bottleneck_tensor) + # Create the operations we need to evaluate the accuracy of our new layer. 
+ evaluation_step = add_evaluation_step(final_tensor, ground_truth_input) + + # Merge all the summaries and write them out to /tmp/retrain_logs (by default) + merged = tf.merge_all_summaries() + train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', + sess.graph) + validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation') + # Set up all our weights to their initial default values. init = tf.initialize_all_variables() sess.run(init) - # Create the operations we need to evaluate the accuracy of our new layer. - evaluation_step = add_evaluation_step(final_tensor, ground_truth_input) - # Run the training for as many cycles as requested on the command line. for i in range(FLAGS.how_many_training_steps): # Get a batch of input bottleneck values, either calculated fresh every time @@ -772,10 +831,12 @@ def main(_): FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, bottleneck_tensor) # Feed the bottlenecks and ground truth into the graph, and run a training - # step. - sess.run(train_step, + # step. Capture training summaries for TensorBoard with the `merged` op. + train_summary, _ = sess.run([merged, train_step], feed_dict={bottleneck_input: train_bottlenecks, ground_truth_input: train_ground_truth}) + train_writer.add_summary(train_summary, i) + # Every so often, print out how well the graph is training. is_last_step = (i + 1 == FLAGS.how_many_training_steps) if (i % FLAGS.eval_step_interval) == 0 or is_last_step: @@ -792,10 +853,13 @@ def main(_): sess, image_lists, FLAGS.validation_batch_size, 'validation', FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, bottleneck_tensor)) - validation_accuracy = sess.run( - evaluation_step, + # Run a validation step and capture validation summaries for TensorBoard + # with the `merged` op. 
+ validation_summary, validation_accuracy = sess.run( + [merged, evaluation_step], feed_dict={bottleneck_input: validation_bottlenecks, ground_truth_input: validation_ground_truth}) + validation_writer.add_summary(validation_summary, i) print('%s: Step %d: Validation accuracy = %.1f%%' % (datetime.now(), i, validation_accuracy * 100)) |