aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/label_image
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-27 16:33:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-27 16:37:09 -0700
commit50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch)
tree7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/examples/label_image
parentd6d58a3a1785785679af56c0f8f131e7312b8226 (diff)
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r--tensorflow/examples/label_image/README.md23
-rw-r--r--tensorflow/examples/label_image/label_image.py132
2 files changed, 154 insertions, 1 deletions
diff --git a/tensorflow/examples/label_image/README.md b/tensorflow/examples/label_image/README.md
index 1103caf586..c5857f394a 100644
--- a/tensorflow/examples/label_image/README.md
+++ b/tensorflow/examples/label_image/README.md
@@ -1,4 +1,4 @@
-# TensorFlow C++ Image Recognition Demo
+# TensorFlow C++ and Python Image Recognition Demo
This example shows how you can load a pre-trained TensorFlow network and use it
to recognize objects in images in C++. For Java see the [Java
@@ -64,3 +64,24 @@ $ bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png
For a more detailed look at this code, you can check out the C++ section of the
[Inception tutorial](https://tensorflow.org/tutorials/image_recognition/).
+
+## Python implementation
+
+label_image.py is a python implementation that provides code corresponding
+to the C++ code here. This gives more intuitive mapping between C++ and
+Python than the Python code mentioned in the
+[Inception tutorial](https://tensorflow.org/tutorials/image_recognition/).
+and could be easier to add visualization or debug code.
+
+With tensorflow python package installed, you can run it like:
+```bash
+$ python3 tensorflow/examples/label_image/label_image.py
+```
+And get result similar to this:
+```
+military uniform 0.834305
+mortarboard 0.0218694
+academic gown 0.0103581
+pickelhaube 0.00800818
+bulletproof vest 0.0053509
+```
diff --git a/tensorflow/examples/label_image/label_image.py b/tensorflow/examples/label_image/label_image.py
new file mode 100644
index 0000000000..39d0981337
--- /dev/null
+++ b/tensorflow/examples/label_image/label_image.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+def load_graph(model_file):
+ graph = tf.Graph()
+ graph_def = tf.GraphDef()
+
+ with open(model_file, "rb") as f:
+ graph_def.ParseFromString(f.read())
+ with graph.as_default():
+ tf.import_graph_def(graph_def)
+
+ return graph
+
+def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
+ input_mean=0, input_std=255):
+ input_name = "file_reader"
+ output_name = "normalized"
+ file_reader = tf.read_file(file_name, input_name)
+ if file_name.endswith(".png"):
+ image_reader = tf.image.decode_png(file_reader, channels = 3,
+ name='png_reader')
+ elif file_name.endswith(".gif"):
+ image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
+ name='gif_reader'))
+ elif file_name.endswith(".bmp"):
+ image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
+ else:
+ image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
+ name='jpeg_reader')
+ float_caster = tf.cast(image_reader, tf.float32)
+ dims_expander = tf.expand_dims(float_caster, 0);
+ resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
+ normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
+ sess = tf.Session()
+ result = sess.run(normalized)
+
+ return result
+
+def load_labels(label_file):
+ label = []
+ proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
+ for l in proto_as_ascii_lines:
+ label.append(l.rstrip())
+ return label
+
+if __name__ == "__main__":
+ file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
+ model_file = \
+ "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb"
+ label_file = "tensorflow/examples/label_image/data/imagenet_slim_labels.txt"
+ input_height = 299
+ input_width = 299
+ input_mean = 0
+ input_std = 255
+ input_layer = "input"
+ output_layer = "InceptionV3/Predictions/Reshape_1"
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--image", help="image to be processed")
+ parser.add_argument("--graph", help="graph/model to be executed")
+ parser.add_argument("--labels", help="name of file containing labels")
+ parser.add_argument("--input_height", type=int, help="input height")
+ parser.add_argument("--input_width", type=int, help="input width")
+ parser.add_argument("--input_mean", type=int, help="input mean")
+ parser.add_argument("--input_std", type=int, help="input std")
+ parser.add_argument("--input_layer", help="name of input layer")
+ parser.add_argument("--output_layer", help="name of output layer")
+ args = parser.parse_args()
+
+ if args.graph:
+ model_file = args.graph
+ if args.image:
+ file_name = args.image
+ if args.labels:
+ label_file = args.labels
+ if args.input_height:
+ input_height = args.input_height
+ if args.input_width:
+ input_width = args.input_width
+ if args.input_mean:
+ input_mean = args.input_mean
+ if args.input_std:
+ input_std = args.input_std
+ if args.input_layer:
+ input_layer = args.input_layer
+ if args.output_layer:
+ output_layer = args.output_layer
+
+ graph = load_graph(model_file)
+ t = read_tensor_from_image_file(file_name,
+ input_height=input_height,
+ input_width=input_width,
+ input_mean=input_mean,
+ input_std=input_std)
+
+ input_name = "import/" + input_layer
+ output_name = "import/" + output_layer
+ input_operation = graph.get_operation_by_name(input_name);
+ output_operation = graph.get_operation_by_name(output_name);
+
+ with tf.Session(graph=graph) as sess:
+ results = sess.run(output_operation.outputs[0],
+ {input_operation.outputs[0]: t})
+ results = np.squeeze(results)
+
+ top_k = results.argsort()[-5:][::-1]
+ labels = load_labels(label_file)
+ for i in top_k:
+ print(labels[i], results[i])