diff options
author | Koan-Sin Tan <koansin.tan@gmail.com> | 2018-06-04 12:02:37 +0800 |
---|---|---|
committer | Koan-Sin Tan <koansin.tan@gmail.com> | 2018-06-04 13:20:16 +0800 |
commit | fa0e2e361aa9b2ca4496b93cef5917f8c359d27d (patch) | |
tree | 442a04686502ccee44ffd7538014cc2c193a157e /tensorflow/contrib/lite/examples | |
parent | a8ae26ae1aa7a33b48cca8bf12c42ab7503a45cf (diff) |
[tflite] label_image for tflite in Python
With model (mobilenet_v1_1.0_224_quant.tflite), input image
(grace_hooper.bmp), and labels file (labels.txt) in /tmp.
Run
```
bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image
```
We can get results like
```
0.470588: military uniform
0.337255: Windsor tie
0.047059: bow tie
0.031373: mortarboard
0.019608: suit
```
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r-- | tensorflow/contrib/lite/examples/python/BUILD | 13 | ||||
-rw-r--r-- | tensorflow/contrib/lite/examples/python/label_image.py | 97 |
2 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/examples/python/BUILD b/tensorflow/contrib/lite/examples/python/BUILD new file mode 100644 index 0000000000..d337c3ddc4 --- /dev/null +++ b/tensorflow/contrib/lite/examples/python/BUILD @@ -0,0 +1,13 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_binary( + name = "label_image", + srcs = ["label_image.py"], + main = "label_image.py", + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/lite/python:lite", + ], +) diff --git a/tensorflow/contrib/lite/examples/python/label_image.py b/tensorflow/contrib/lite/examples/python/label_image.py new file mode 100644 index 0000000000..77e0d7cb4d --- /dev/null +++ b/tensorflow/contrib/lite/examples/python/label_image.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""label_image for tflite""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import numpy as np + +from PIL import Image + +from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper + +def load_labels(filename): + my_labels = [] + input_file = open(filename, 'r') + for l in input_file: + my_labels.append(l.strip()) + return my_labels + +if __name__ == "__main__": + file_name = "/tmp/grace_hopper.bmp" + model_file = "/tmp/mobilenet_v1_1.0_224_quant.tflite" + label_file = "/tmp/labels.txt" + input_mean = 127.5 + input_std = 127.5 + floating_model = False + + parser = argparse.ArgumentParser() + parser.add_argument("--image", help="image to be classified") + parser.add_argument("--graph", help=".tflite model to be executed") + parser.add_argument("--labels", help="name of file containing labels") + parser.add_argument("--input_mean", help="input_mean") + parser.add_argument("--input_std", help="input standard deviation") + args = parser.parse_args() + + if args.graph: + model_file = args.graph + if args.image: + file_name = args.image + if args.labels: + label_file = args.labels + if args.input_mean: + input_mean = args.input_mean + if args.input_std: + input_std = args.input_std + + interpreter = interpreter_wrapper.Interpreter(model_path=model_file) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # check the type of the input tensor + if input_details[0]['dtype'] == type(np.float32(1.0)): + floating_model = True + + # NxHxWxC, H:1, W:2 + height = input_details[0]['shape'][1] + width = input_details[0]['shape'][2] + img = Image.open(file_name) + img = img.resize((width, height)) + + # add N dim + input_data = np.expand_dims(img, axis=0) + + if floating_model: + input_data = (np.float32(input_data) - input_mean) / input_std + + interpreter.set_tensor(input_details[0]['index'], input_data) + + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + results = np.squeeze(output_data) + + top_k = results.argsort()[-5:][::-1] + labels = load_labels(label_file) + for i in top_k: + if floating_model: + print('{0:08.6f}'.format(float(results[i]))+":", labels[i]) + else: + print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i]) |