aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples
diff options
context:
space:
mode:
authorGravatar Koan-Sin Tan <koansin.tan@gmail.com>2018-06-04 12:02:37 +0800
committerGravatar Koan-Sin Tan <koansin.tan@gmail.com>2018-06-04 13:20:16 +0800
commitfa0e2e361aa9b2ca4496b93cef5917f8c359d27d (patch)
tree442a04686502ccee44ffd7538014cc2c193a157e /tensorflow/contrib/lite/examples
parenta8ae26ae1aa7a33b48cca8bf12c42ab7503a45cf (diff)
[tflite] label_image for tflite in Python
With model (mobilenet_v1_1.0_224_quant.tflite), input image (grace_hooper.bmp), and labels file (labels.txt) in /tmp. Run ``` bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image ``` We can get results like ``` 0.470588: military uniform 0.337255: Windsor tie 0.047059: bow tie 0.031373: mortarboard 0.019608: suit ```
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r--tensorflow/contrib/lite/examples/python/BUILD13
-rw-r--r--tensorflow/contrib/lite/examples/python/label_image.py97
2 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/examples/python/BUILD b/tensorflow/contrib/lite/examples/python/BUILD
new file mode 100644
index 0000000000..d337c3ddc4
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/python/BUILD
@@ -0,0 +1,13 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_binary(
+ name = "label_image",
+ srcs = ["label_image.py"],
+ main = "label_image.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/lite/python:lite",
+ ],
+)
diff --git a/tensorflow/contrib/lite/examples/python/label_image.py b/tensorflow/contrib/lite/examples/python/label_image.py
new file mode 100644
index 0000000000..77e0d7cb4d
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/python/label_image.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""label_image for tflite"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import numpy as np
+
+from PIL import Image
+
+from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
+
+def load_labels(filename):
+ my_labels = []
+ input_file = open(filename, 'r')
+ for l in input_file:
+ my_labels.append(l.strip())
+ return my_labels
+
+if __name__ == "__main__":
+ file_name = "/tmp/grace_hopper.bmp"
+ model_file = "/tmp/mobilenet_v1_1.0_224_quant.tflite"
+ label_file = "/tmp/labels.txt"
+ input_mean = 127.5
+ input_std = 127.5
+ floating_model = False
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--image", help="image to be classified")
+ parser.add_argument("--graph", help=".tflite model to be executed")
+ parser.add_argument("--labels", help="name of file containing labels")
+ parser.add_argument("--input_mean", help="input_mean")
+ parser.add_argument("--input_std", help="input standard deviation")
+ args = parser.parse_args()
+
+ if args.graph:
+ model_file = args.graph
+ if args.image:
+ file_name = args.image
+ if args.labels:
+ label_file = args.labels
+ if args.input_mean:
+ input_mean = args.input_mean
+ if args.input_std:
+ input_std = args.input_std
+
+ interpreter = interpreter_wrapper.Interpreter(model_path=model_file)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # check the type of the input tensor
+ if input_details[0]['dtype'] == type(np.float32(1.0)):
+ floating_model = True
+
+ # NxHxWxC, H:1, W:2
+ height = input_details[0]['shape'][1]
+ width = input_details[0]['shape'][2]
+ img = Image.open(file_name)
+ img = img.resize((width, height))
+
+ # add N dim
+ input_data = np.expand_dims(img, axis=0)
+
+ if floating_model:
+ input_data = (np.float32(input_data) - input_mean) / input_std
+
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ results = np.squeeze(output_data)
+
+ top_k = results.argsort()[-5:][::-1]
+ labels = load_labels(label_file)
+ for i in top_k:
+ if floating_model:
+ print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
+ else:
+ print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])