aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2017-05-18 13:05:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-18 13:10:01 -0700
commit315f195f728d59647f23f10ecc664b8ce9466a06 (patch)
tree350fb890c4100400e56df29f04088fd456d85e0f /tensorflow/examples/image_retraining
parentf41a9113bcac0e144d214cb101a991ca8fbff963 (diff)
Add label_image.py that works with retrain.py.
PiperOrigin-RevId: 156468713
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/BUILD17
-rw-r--r--tensorflow/examples/image_retraining/data/labels.txt3
-rw-r--r--tensorflow/examples/image_retraining/label_image.py147
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py32
4 files changed, 199 insertions, 0 deletions
diff --git a/tensorflow/examples/image_retraining/BUILD b/tensorflow/examples/image_retraining/BUILD
index 5a885e33c2..394c413b33 100644
--- a/tensorflow/examples/image_retraining/BUILD
+++ b/tensorflow/examples/image_retraining/BUILD
@@ -24,13 +24,30 @@ py_binary(
],
)
+py_binary(
+ name = "label_image",
+ srcs = [
+ "label_image.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
py_test(
name = "retrain_test",
size = "small",
srcs = [
+ "label_image.py",
"retrain.py",
"retrain_test.py",
],
+ data = [
+ ":data/labels.txt",
+ "//tensorflow/examples/label_image:data/grace_hopper.jpg",
+ ],
srcs_version = "PY2AND3",
deps = [
":retrain",
diff --git a/tensorflow/examples/image_retraining/data/labels.txt b/tensorflow/examples/image_retraining/data/labels.txt
new file mode 100644
index 0000000000..bc1131ac45
--- /dev/null
+++ b/tensorflow/examples/image_retraining/data/labels.txt
@@ -0,0 +1,3 @@
+Runner-up
+Winner
+Loser
diff --git a/tensorflow/examples/image_retraining/label_image.py b/tensorflow/examples/image_retraining/label_image.py
new file mode 100644
index 0000000000..ecfa672462
--- /dev/null
+++ b/tensorflow/examples/image_retraining/label_image.py
@@ -0,0 +1,147 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple image classification with Inception.
+
+Run image classification with your model.
+
+This script is usually used with retrain.py found in this same
+directory.
+
+This program creates a graph from a saved GraphDef protocol buffer,
+and runs inference on an input JPEG image. You are required
+to pass in the graph file and the txt file.
+
+It outputs human readable strings of the top 5 predictions along with
+their probabilities.
+
+Change the --image_file argument to any jpg image to compute a
+classification of that image.
+
+Example usage:
+python label_image.py --graph=retrained_graph.pb
+ --labels=retrained_labels.txt
+ --image=flower_photos/daisy/54377391_15648e8d18.jpg
+
+NOTE: To learn to use this file and retrain.py, please see:
+
+https://codelabs.developers.google.com/codelabs/tensorflow-for-poets
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import tensorflow as tf
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ '--image', required=True, type=str, help='Absolute path to image file.')
+parser.add_argument(
+ '--num_top_predictions',
+ type=int,
+ default=5,
+ help='Display this many predictions.')
+parser.add_argument(
+ '--graph',
+ required=True,
+ type=str,
+ help='Absolute path to graph file (.pb)')
+parser.add_argument(
+ '--labels',
+ required=True,
+ type=str,
+ help='Absolute path to labels file (.txt)')
+parser.add_argument(
+ '--output_layer',
+ type=str,
+ default='final_result:0',
+ help='Name of the result operation')
+parser.add_argument(
+ '--input_layer',
+ type=str,
+ default='DecodeJpeg/contents:0',
+ help='Name of the input operation')
+
+
+def load_image(filename):
+ """Read in the image_data to be classified."""
+ return tf.gfile.FastGFile(filename, 'rb').read()
+
+
+def load_labels(filename):
+ """Read in labels, one label per line."""
+ return [line.rstrip() for line in tf.gfile.GFile(filename)]
+
+
+def load_graph(filename):
+ """Unpersists graph from file as default graph."""
+ with tf.gfile.FastGFile(filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name='')
+
+
+def run_graph(image_data, labels, input_layer_name, output_layer_name,
+ num_top_predictions):
+ with tf.Session() as sess:
+ # Feed the image_data as input to the graph.
+ # predictions will contain a two-dimensional array, where one
+ # dimension represents the input image count, and the other has
+ # predictions per class
+ softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
+ predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})
+
+ # Sort to show labels in order of confidence
+ top_k = predictions.argsort()[-num_top_predictions:][::-1]
+ for node_id in top_k:
+ human_string = labels[node_id]
+ score = predictions[node_id]
+ print('%s (score = %.5f)' % (human_string, score))
+
+ return 0
+
+
+def main(argv):
+ """Runs inference on an image."""
+ if argv[1:]:
+ raise ValueError('Unused Command Line Args: %s' % argv[1:])
+
+ if not tf.gfile.Exists(FLAGS.image):
+ tf.logging.fatal('image file does not exist %s', FLAGS.image)
+
+ if not tf.gfile.Exists(FLAGS.labels):
+ tf.logging.fatal('labels file does not exist %s', FLAGS.labels)
+
+ if not tf.gfile.Exists(FLAGS.graph):
+ tf.logging.fatal('graph file does not exist %s', FLAGS.graph)
+
+ # load image
+ image_data = load_image(FLAGS.image)
+
+ # load labels
+ labels = load_labels(FLAGS.labels)
+
+ # load graph, which is stored in the default session
+ load_graph(FLAGS.graph)
+
+ run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer,
+ FLAGS.num_top_predictions)
+
+
+if __name__ == '__main__':
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=sys.argv[:1]+unparsed)
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 00ccea174f..8af5cc7114 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -19,7 +19,9 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
+import os
+from tensorflow.examples.image_retraining import label_image
from tensorflow.examples.image_retraining import retrain
from tensorflow.python.framework import test_util
@@ -81,5 +83,35 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
gt = tf.placeholder(tf.float32, [1], name='gt')
self.assertIsNotNone(retrain.add_evaluation_step(final, gt))
+ def testLabelImage(self):
+
+ image_filename = ('../label_image/data/grace_hopper.jpg')
+
+ # Load some default data
+ label_path = os.path.join(tf.resource_loader.get_data_files_path(),
+ 'data/labels.txt')
+ labels = label_image.load_labels(label_path)
+ self.assertEqual(len(labels), 3)
+
+ image_path = os.path.join(tf.resource_loader.get_data_files_path(),
+ image_filename)
+
+ image = label_image.load_image(image_path)
+ self.assertEqual(len(image), 61306)
+
+ # Create trivial graph; note that the two nodes don't meet
+ with tf.Graph().as_default():
+ jpeg = tf.constant(image)
+ # Input node that doesn't lead anywhere.
+ tf.image.decode_jpeg(jpeg, name='DecodeJpeg')
+
+ # Output node, that always outputs a constant.
+ tf.constant([[10, 30, 5]], name='final')
+
+ # As label_image outputs via print, we assume that
+ # if it returns, everything is OK.
+ result = label_image.run_graph(image, labels, jpeg, 'final:0', 3)
+ self.assertEqual(result, 0)
+
if __name__ == '__main__':
tf.test.main()