diff options
author | Mark Daoust <markdaoust@google.com> | 2017-05-18 13:05:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-18 13:10:01 -0700 |
commit | 315f195f728d59647f23f10ecc664b8ce9466a06 (patch) | |
tree | 350fb890c4100400e56df29f04088fd456d85e0f /tensorflow/examples/image_retraining | |
parent | f41a9113bcac0e144d214cb101a991ca8fbff963 (diff) |
Add label_image.py that works with retrain.py.
PiperOrigin-RevId: 156468713
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/data/labels.txt | 3 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/label_image.py | 147 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/retrain_test.py | 32 |
4 files changed, 199 insertions, 0 deletions
diff --git a/tensorflow/examples/image_retraining/BUILD b/tensorflow/examples/image_retraining/BUILD index 5a885e33c2..394c413b33 100644 --- a/tensorflow/examples/image_retraining/BUILD +++ b/tensorflow/examples/image_retraining/BUILD @@ -24,13 +24,30 @@ py_binary( ], ) +py_binary( + name = "label_image", + srcs = [ + "label_image.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + py_test( name = "retrain_test", size = "small", srcs = [ + "label_image.py", "retrain.py", "retrain_test.py", ], + data = [ + ":data/labels.txt", + "//tensorflow/examples/label_image:data/grace_hopper.jpg", + ], srcs_version = "PY2AND3", deps = [ ":retrain", diff --git a/tensorflow/examples/image_retraining/data/labels.txt b/tensorflow/examples/image_retraining/data/labels.txt new file mode 100644 index 0000000000..bc1131ac45 --- /dev/null +++ b/tensorflow/examples/image_retraining/data/labels.txt @@ -0,0 +1,3 @@ +Runner-up +Winner +Loser diff --git a/tensorflow/examples/image_retraining/label_image.py b/tensorflow/examples/image_retraining/label_image.py new file mode 100644 index 0000000000..ecfa672462 --- /dev/null +++ b/tensorflow/examples/image_retraining/label_image.py @@ -0,0 +1,147 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple image classification with Inception. + +Run image classification with your model. + +This script is usually used with retrain.py found in this same +directory. + +This program creates a graph from a saved GraphDef protocol buffer, +and runs inference on an input JPEG image. You are required +to pass in the graph file and the txt file. + +It outputs human readable strings of the top 5 predictions along with +their probabilities. + +Change the --image_file argument to any jpg image to compute a +classification of that image. + +Example usage: +python label_image.py --graph=retrained_graph.pb + --labels=retrained_labels.txt + --image=flower_photos/daisy/54377391_15648e8d18.jpg + +NOTE: To learn to use this file and retrain.py, please see: + +https://codelabs.developers.google.com/codelabs/tensorflow-for-poets +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tensorflow as tf + +parser = argparse.ArgumentParser() +parser.add_argument( + '--image', required=True, type=str, help='Absolute path to image file.') +parser.add_argument( + '--num_top_predictions', + type=int, + default=5, + help='Display this many predictions.') +parser.add_argument( + '--graph', + required=True, + type=str, + help='Absolute path to graph file (.pb)') +parser.add_argument( + '--labels', + required=True, + type=str, + help='Absolute path to labels file (.txt)') +parser.add_argument( + '--output_layer', + type=str, + default='final_result:0', + help='Name of the result operation') +parser.add_argument( + '--input_layer', + type=str, + default='DecodeJpeg/contents:0', + help='Name of the input operation') + + +def load_image(filename): + """Read in the image_data to be classified.""" + return tf.gfile.FastGFile(filename, 'rb').read() + + +def load_labels(filename): + """Read in labels, one label per line.""" + return [line.rstrip() for line in tf.gfile.GFile(filename)] + + +def load_graph(filename): + """Unpersists graph from file as default graph.""" + with tf.gfile.FastGFile(filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + +def run_graph(image_data, labels, input_layer_name, output_layer_name, + num_top_predictions): + with tf.Session() as sess: + # Feed the image_data as input to the graph. + # predictions will contain a two-dimensional array, where one + # dimension represents the input image count, and the other has + # predictions per class + softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) + predictions, = sess.run(softmax_tensor, {input_layer_name: image_data}) + + # Sort to show labels in order of confidence + top_k = predictions.argsort()[-num_top_predictions:][::-1] + for node_id in top_k: + human_string = labels[node_id] + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + + return 0 + + +def main(argv): + """Runs inference on an image.""" + if argv[1:]: + raise ValueError('Unused Command Line Args: %s' % argv[1:]) + + if not tf.gfile.Exists(FLAGS.image): + tf.logging.fatal('image file does not exist %s', FLAGS.image) + + if not tf.gfile.Exists(FLAGS.labels): + tf.logging.fatal('labels file does not exist %s', FLAGS.labels) + + if not tf.gfile.Exists(FLAGS.graph): + tf.logging.fatal('graph file does not exist %s', FLAGS.graph) + + # load image + image_data = load_image(FLAGS.image) + + # load labels + labels = load_labels(FLAGS.labels) + + # load graph, which is stored in the default session + load_graph(FLAGS.graph) + + run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer, + FLAGS.num_top_predictions) + + +if __name__ == '__main__': + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=sys.argv[:1]+unparsed) diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index 00ccea174f..8af5cc7114 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function import tensorflow as tf +import os +from tensorflow.examples.image_retraining import label_image from tensorflow.examples.image_retraining import retrain from tensorflow.python.framework import test_util @@ -81,5 +83,35 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): gt = tf.placeholder(tf.float32, [1], name='gt') self.assertIsNotNone(retrain.add_evaluation_step(final, gt)) + def testLabelImage(self): + + image_filename = ('../label_image/data/grace_hopper.jpg') + + # Load some default data + label_path = os.path.join(tf.resource_loader.get_data_files_path(), + 'data/labels.txt') + labels = label_image.load_labels(label_path) + self.assertEqual(len(labels), 3) + + image_path = os.path.join(tf.resource_loader.get_data_files_path(), + image_filename) + + image = label_image.load_image(image_path) + self.assertEqual(len(image), 61306) + + # Create trivial graph; note that the two nodes don't meet + with tf.Graph().as_default(): + jpeg = tf.constant(image) + # Input node that doesn't lead anywhere. + tf.image.decode_jpeg(jpeg, name='DecodeJpeg') + + # Output node, that always outputs a constant. + tf.constant([[10, 30, 5]], name='final') + + # As label_image outputs via print, we assume that + # if it returns, everything is OK. + result = label_image.run_graph(image, labels, jpeg, 'final:0', 3) + self.assertEqual(result, 0) + if __name__ == '__main__': tf.test.main() |