aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 09:47:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 09:47:59 -0700
commitb065a6cfd1fbdc77cff13c2b3b83fe018df8966f (patch)
treeb8116e57b96d3a6ec9a585f2be14a2290a6430fe /tensorflow/contrib/lite/examples
parente51791dd3bfe80a17b78780b620f9832b1b62474 (diff)
parent57dacd87afc9d6e30bb11480deccf5481f8d3bc3 (diff)
Merge pull request #19736 from freedomtan:label_image_tflite_py
PiperOrigin-RevId: 208062989
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r--tensorflow/contrib/lite/examples/python/BUILD13
-rw-r--r--tensorflow/contrib/lite/examples/python/label_image.md50
-rw-r--r--tensorflow/contrib/lite/examples/python/label_image.py86
3 files changed, 149 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/examples/python/BUILD b/tensorflow/contrib/lite/examples/python/BUILD
new file mode 100644
index 0000000000..d337c3ddc4
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/python/BUILD
@@ -0,0 +1,13 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_binary(
+ name = "label_image",
+ srcs = ["label_image.py"],
+ main = "label_image.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/lite/python:lite",
+ ],
+)
diff --git a/tensorflow/contrib/lite/examples/python/label_image.md b/tensorflow/contrib/lite/examples/python/label_image.md
new file mode 100644
index 0000000000..e81192a96c
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/python/label_image.md
@@ -0,0 +1,50 @@
+
+With model, input image (grace_hopper.bmp), and labels file (labels.txt)
+in /tmp.
+
+The example input image and labels file are from TensorFlow repo and
+MobileNet V1 model files.
+
+```
+curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp
+
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C /tmp mobilenet_v1_1.0_224/labels.txt
+mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/
+
+```
+
+Run
+
+```
+curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C /tmp
+bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image
+```
+
+We can get results like
+
+```
+0.470588: military uniform
+0.337255: Windsor tie
+0.047059: bow tie
+0.031373: mortarboard
+0.019608: suit
+```
+
+Run
+
+```
+curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp
+bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image \
+-- --model_file /tmp/mobilenet_v1_1.0_224.tflite
+```
+
+We can get results like
+```
+0.728693: military uniform
+0.116163: Windsor tie
+0.035517: bow tie
+0.014874: mortarboard
+0.011758: bolo tie
+```
+
+Check [models](../../g3doc/models.md) for models hosted by Google.
diff --git a/tensorflow/contrib/lite/examples/python/label_image.py b/tensorflow/contrib/lite/examples/python/label_image.py
new file mode 100644
index 0000000000..282118a1d2
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/python/label_image.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""label_image for tflite"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import numpy as np
+
+from PIL import Image
+
+from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
+
+def load_labels(filename):
+ my_labels = []
+ input_file = open(filename, 'r')
+ for l in input_file:
+ my_labels.append(l.strip())
+ return my_labels
+
+if __name__ == "__main__":
+ floating_model = False
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \
+ help="image to be classified")
+ parser.add_argument("-m", "--model_file", \
+ default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \
+ help=".tflite model to be executed")
+ parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \
+ help="name of file containing labels")
+ parser.add_argument("--input_mean", default=127.5, help="input_mean")
+ parser.add_argument("--input_std", default=127.5, \
+ help="input standard deviation")
+ args = parser.parse_args()
+
+ interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # check the type of the input tensor
+ if input_details[0]['dtype'] == np.float32:
+ floating_model = True
+
+ # NxHxWxC, H:1, W:2
+ height = input_details[0]['shape'][1]
+ width = input_details[0]['shape'][2]
+ img = Image.open(args.image)
+ img = img.resize((width, height))
+
+ # add N dim
+ input_data = np.expand_dims(img, axis=0)
+
+ if floating_model:
+ input_data = (np.float32(input_data) - args.input_mean) / args.input_std
+
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ results = np.squeeze(output_data)
+
+ top_k = results.argsort()[-5:][::-1]
+ labels = load_labels(args.label_file)
+ for i in top_k:
+ if floating_model:
+ print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
+ else:
+ print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])