author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-06-24 16:50:40 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-06-24 18:05:02 -0700
commit    67324b1e3af826c4c491802f4022a5f5be9f6670 (patch)
tree      018e21feade905a14d8beb2c8e7ebdfd905bbf51 /tensorflow/examples/image_retraining
parent    e30936c026655f1b2f4f45997da32c257d18b076 (diff)
Merge changes from github.
Change: 125835079
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--  tensorflow/examples/image_retraining/retrain.py | 118
1 file changed, 91 insertions(+), 27 deletions(-)
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 9b01376577..6a3024d5bc 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Simple transfer learning with an Inception v3 architecture model.
+"""Simple transfer learning with an Inception v3 architecture model which
+displays summaries in TensorBoard.
This example shows how to take an Inception v3 architecture model trained on
ImageNet images, and train a new top layer that can recognize other classes of
@@ -49,6 +50,15 @@ in.
This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.
+
+To use with TensorBoard:
+
+By default, this script will log summaries to the /tmp/retrain_logs directory.
+
+Visualize the summaries with this command:
+
+tensorboard --logdir /tmp/retrain_logs
+
"""
from __future__ import absolute_import
from __future__ import division
@@ -81,6 +91,8 @@ tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb',
"""Where to save the trained graph.""")
tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt',
"""Where to save the trained graph's labels.""")
+tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs',
+ """Where to save summary logs for TensorBoard.""")
# Details of the training configuration.
tf.app.flags.DEFINE_integer('how_many_training_steps', 4000,
@@ -650,6 +662,19 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
return jpeg_data, distort_result
+def variable_summaries(var, name):
+ """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+ with tf.name_scope('summaries'):
+ mean = tf.reduce_mean(var)
+ tf.scalar_summary('mean/' + name, mean)
+ with tf.name_scope('stddev'):
+ stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
+ tf.scalar_summary('stddev/' + name, stddev)
+ tf.scalar_summary('max/' + name, tf.reduce_max(var))
+ tf.scalar_summary('min/' + name, tf.reduce_min(var))
+ tf.histogram_summary(name, var)
+
+
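The variable_summaries helper added here is attached to a tensor once, at graph-construction time; everything it records is later gathered by tf.merge_all_summaries(). A minimal usage sketch (TF 0.x summary API; the shape, names, and log directory are illustrative, and variable_summaries is assumed to be in scope from the patch above):

    import tensorflow as tf

    weights = tf.Variable(tf.truncated_normal([10, 5], stddev=0.1), name='weights')
    variable_summaries(weights, 'example/weights')  # mean, stddev, max, min, histogram

    merged = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter('/tmp/example_logs')
    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      writer.add_summary(sess.run(merged), 0)  # one summary record at step 0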
def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
"""Adds a new softmax and fully-connected layer for training.
@@ -670,24 +695,43 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
The tensors for the training and cross entropy results, and tensors for the
bottleneck input and ground truth input.
"""
- bottleneck_input = tf.placeholder_with_default(
- bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
- name='BottleneckInputPlaceholder')
- layer_weights = tf.Variable(
- tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001),
- name='final_weights')
- layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
- logits = tf.matmul(bottleneck_input, layer_weights,
- name='final_matmul') + layer_biases
+ with tf.name_scope('input'):
+ bottleneck_input = tf.placeholder_with_default(
+ bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
+ name='BottleneckInputPlaceholder')
+
+ ground_truth_input = tf.placeholder(tf.float32,
+ [None, class_count],
+ name='GroundTruthInput')
+
+ # Organizing the following ops as `final_training_ops` so they're easier
+ # to see in TensorBoard
+ layer_name = 'final_training_ops'
+ with tf.name_scope(layer_name):
+ with tf.name_scope('weights'):
+ layer_weights = tf.Variable(
+ tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001),
+ name='final_weights')
+ variable_summaries(layer_weights, layer_name + '/weights')
+ with tf.name_scope('biases'):
+ layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
+ variable_summaries(layer_biases, layer_name + '/biases')
+ with tf.name_scope('Wx_plus_b'):
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+ tf.histogram_summary(layer_name + '/pre_activations', logits)
+
final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
- ground_truth_input = tf.placeholder(tf.float32,
- [None, class_count],
- name='GroundTruthInput')
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
+ tf.histogram_summary(final_tensor_name + '/activations', final_tensor)
+
+ with tf.name_scope('cross_entropy'):
+ cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
logits, ground_truth_input)
- cross_entropy_mean = tf.reduce_mean(cross_entropy)
- train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
- cross_entropy_mean)
+ with tf.name_scope('total'):
+ cross_entropy_mean = tf.reduce_mean(cross_entropy)
+ tf.scalar_summary('cross entropy', cross_entropy_mean)
+
+ with tf.name_scope('train'):
+ train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
+ cross_entropy_mean)
+
return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
final_tensor)
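The nested tf.name_scope calls introduced in this hunk change only how ops are named and grouped in TensorBoard's graph view; the computation is identical to the code they replace. A small sketch of the effect (TF 0.x; names are illustrative):

    import tensorflow as tf

    with tf.name_scope('final_training_ops'):
      with tf.name_scope('weights'):
        w = tf.Variable(tf.zeros([3]), name='final_weights')

    # Scopes become path components of the op name, which TensorBoard uses
    # to collapse these ops into a single expandable node.
    print(w.name)  # e.g. final_training_ops/weights/final_weights:0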
@@ -703,13 +747,22 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
Returns:
Nothing.
"""
- correct_prediction = tf.equal(
- tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1))
- evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
+ with tf.name_scope('accuracy'):
+ with tf.name_scope('correct_prediction'):
+ correct_prediction = tf.equal(tf.argmax(result_tensor, 1),
+ tf.argmax(ground_truth_tensor, 1))
+ with tf.name_scope('accuracy'):
+ evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ tf.scalar_summary('accuracy', evaluation_step)
return evaluation_step
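To make the evaluation concrete: tf.argmax picks the predicted and true class indices, tf.equal compares them elementwise, and the mean of the resulting 0/1 values is the accuracy. A standalone sketch with made-up values:

    import tensorflow as tf

    # Two examples, two classes; the second prediction is wrong.
    predictions = tf.constant([[0.1, 0.9], [0.8, 0.2]])
    ground_truth = tf.constant([[0.0, 1.0], [0.0, 1.0]])
    correct = tf.equal(tf.argmax(predictions, 1), tf.argmax(ground_truth, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    with tf.Session() as sess:
      print(sess.run(accuracy))  # 0.5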
def main(_):
+ # Set up the directory we'll write summaries to for TensorBoard.
+ if tf.gfile.Exists(FLAGS.summaries_dir):
+ tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
+ tf.gfile.MakeDirs(FLAGS.summaries_dir)
+
# Set up the pre-trained graph.
maybe_download_and_extract()
graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
@@ -750,13 +803,19 @@ def main(_):
FLAGS.final_tensor_name,
bottleneck_tensor)
+ # Create the operations we need to evaluate the accuracy of our new layer.
+ evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
+
+ # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
+ merged = tf.merge_all_summaries()
+ train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train',
+ sess.graph)
+ validation_writer = tf.train.SummaryWriter(
+ FLAGS.summaries_dir + '/validation')
+
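Writing train and validation summaries through two writers under one parent directory is what lets TensorBoard overlay both curves on the same chart. A condensed sketch of the idea with a synthetic metric (TF 0.x API; values are illustrative):

    import tensorflow as tf

    metric = tf.placeholder(tf.float32, name='metric')
    tf.scalar_summary('accuracy', metric)
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter('/tmp/retrain_logs/train')
    validation_writer = tf.train.SummaryWriter('/tmp/retrain_logs/validation')
    with tf.Session() as sess:
      for step in range(3):
        train_writer.add_summary(
            sess.run(merged, {metric: 0.6 + 0.1 * step}), step)
        validation_writer.add_summary(
            sess.run(merged, {metric: 0.5 + 0.1 * step}), step)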
# Set up all our weights to their initial default values.
init = tf.initialize_all_variables()
sess.run(init)
- # Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
-
# Run the training for as many cycles as requested on the command line.
for i in range(FLAGS.how_many_training_steps):
# Get a batch of input bottleneck values, either calculated fresh every time
@@ -772,10 +831,12 @@ def main(_):
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor)
# Feed the bottlenecks and ground truth into the graph, and run a training
- # step.
- sess.run(train_step,
+ # step. Capture training summaries for TensorBoard with the `merged` op.
+ train_summary, _ = sess.run([merged, train_step],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
+ train_writer.add_summary(train_summary, i)
+
# Every so often, print out how well the graph is training.
is_last_step = (i + 1 == FLAGS.how_many_training_steps)
if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
@@ -792,10 +853,13 @@ def main(_):
sess, image_lists, FLAGS.validation_batch_size, 'validation',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor))
- validation_accuracy = sess.run(
- evaluation_step,
+ # Run a validation step and capture validation summaries for TensorBoard
+ # with the `merged` op.
+ validation_summary, validation_accuracy = sess.run(
+ [merged, evaluation_step],
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
+ validation_writer.add_summary(validation_summary, i)
print('%s: Step %d: Validation accuracy = %.1f%%' %
(datetime.now(), i, validation_accuracy * 100))
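Note the pattern used throughout the loop above: merged is evaluated in the same sess.run call as the op of interest, so summaries come from the very forward pass that produced the update or the accuracy, with no extra graph execution. A condensed sketch with a toy model (TF 0.x API; all names are illustrative):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 1])
    y = tf.placeholder(tf.float32, [None, 1])
    w = tf.Variable(tf.zeros([1, 1]))
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    tf.scalar_summary('loss', loss)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    merged = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter('/tmp/toy_logs')
    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      for step in range(5):
        # One graph execution yields both the summary protobuf and the update.
        summary, _ = sess.run([merged, train_op],
                              feed_dict={x: [[1.0]], y: [[2.0]]})
        writer.add_summary(summary, step)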