aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar "freedom" Koan-Sin Tan <koansin.tan@gmail.com>2018-01-26 06:24:23 +0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-01-25 14:24:23 -0800
commit119b993e00a2a138d1ef2f4886e39ca528bad1c3 (patch)
tree487e8d3fdb5c13ff4860b21ac59800d3484b38cd
parentdce14f5d38a19b3a19c834a877acc0042f1e9fbd (diff)
make label_image for tflite build again (#16206)
* make label_image for tflite build again
  1. add namespace to label_image.h to make label_image for tflite build again
  2. add --config monolithic and mention NDK settings in label_image.md
  3. fix a typo in display_usage()
* relies on tensor type info to switch types: use tensor types of input and output tensors instead of command line flag from user
* reformatted according to review
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc55
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h7
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.md12
3 files changed, 46 insertions, 28 deletions
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 4d2e1ce0bc..d7f49ad875 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -148,14 +148,22 @@ void RunInference(Settings* s) {
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
- if (s->input_floating) {
- downsize<float>(interpreter->typed_tensor<float>(input), in, image_height,
- image_width, image_channels, wanted_height, wanted_width,
- wanted_channels, s);
- } else {
- downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
- image_height, image_width, image_channels, wanted_height,
- wanted_width, wanted_channels, s);
+ switch (interpreter->tensor(input)->type) {
+ case kTfLiteFloat32:
+ s->input_floating = true;
+ downsize<float>(interpreter->typed_tensor<float>(input), in,
+ image_height, image_width, image_channels,
+ wanted_height, wanted_width, wanted_channels, s);
+ break;
+ case kTfLiteUInt8:
+ downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
+ image_height, image_width, image_channels,
+ wanted_height, wanted_width, wanted_channels, s);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle input type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
struct timeval start_time, stop_time;
@@ -177,13 +185,22 @@ void RunInference(Settings* s) {
std::vector<std::pair<float, int>> top_results;
- if (s->input_floating) {
- get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
- num_results, threshold, &top_results, s->input_floating);
- } else {
- get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
+ int output = interpreter->outputs()[0];
+ switch (interpreter->tensor(output)->type) {
+ case kTfLiteFloat32:
+ get_top_n<float>(interpreter->typed_output_tensor<float>(0),
output_size, num_results, threshold, &top_results,
- s->input_floating);
+ true);
+ break;
+ case kTfLiteUInt8:
+ get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
+ output_size, num_results, threshold, &top_results,
+ false);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle output type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
std::vector<string> labels;
@@ -203,13 +220,11 @@ void display_usage() {
LOG(INFO) << "label_image\n"
<< "--accelerated, -a: [0|1], use Android NNAPI or note\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
- << "--input_floating, -f: [0|1] type of input layer is floating "
- "point numbers\n"
<< "--input_mean, -b: input mean\n"
<< "--input_std, -s: input standard deviation\n"
<< "--image, -i: image_name.bmp\n"
<< "--labels, -l: labels for the model\n"
- << "--tflite_mode, -m: model_name.tflite\n"
+ << "--tflite_model, -m: model_name.tflite\n"
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "\n";
@@ -223,7 +238,6 @@ int Main(int argc, char** argv) {
static struct option long_options[] = {
{"accelerated", required_argument, 0, 'a'},
{"count", required_argument, 0, 'c'},
- {"input_floating", required_argument, 0, 'f'},
{"verbose", required_argument, 0, 'v'},
{"image", required_argument, 0, 'i'},
{"labels", required_argument, 0, 'l'},
@@ -254,11 +268,6 @@ int Main(int argc, char** argv) {
s.loop_count = strtol( // NOLINT(runtime/deprecated_fn)
optarg, (char**)NULL, 10);
break;
- case 'f':
- s.input_floating = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
- s.input_layer_type = "float";
- break;
case 'i':
s.input_bmp_name = optarg;
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index ce98e06fc1..4de32e33fb 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -16,9 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
-#include <string>
#include "tensorflow/contrib/lite/string.h"
+namespace tflite {
+namespace label_image {
+
struct Settings {
bool verbose = false;
bool accel = false;
@@ -33,4 +35,7 @@ struct Settings {
int number_of_threads = 4;
};
+} // namespace label_image
+} // namespace tflite
+
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md
index d6019d673f..9ce32cf101 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.md
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.md
@@ -1,8 +1,12 @@
label_image for TensorFlow Lite inspired by TensorFlow's label_image.
+
+To build label_image for Android, run $TENSORFLOW_ROOT/configure
+and set Android NDK or configure NDK setting in
+$TENSORFLOW_ROOT/WORKSPACE first.
To build it for android ARMv8:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=arm64-v8a \
@@ -10,13 +14,13 @@ To build it for android ARMv8:
```
or
```
-> bazel build --config android_arm64 --cxxopt=-std=c++11 \
+> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```
To build it for android arm-v7a:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=armeabi-v7a \
@@ -24,7 +28,7 @@ To build it for android arm-v7a:
```
or
```
-> bazel build --config android_arm --cxxopt=-std=c++11 \
+> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```