diff options
author | 2018-01-26 06:24:23 +0800 | |
---|---|---|
committer | 2018-01-25 14:24:23 -0800 | |
commit | 119b993e00a2a138d1ef2f4886e39ca528bad1c3 (patch) | |
tree | 487e8d3fdb5c13ff4860b21ac59800d3484b38cd | |
parent | dce14f5d38a19b3a19c834a877acc0042f1e9fbd (diff) |
make label_image for tflite build again (#16206)
* make label_image for tflite build again
1. add namespace to label_image.h to make label_image for tflite build again
2. add --config monolithic and mention NDK settings in label_image.md
3. fix a typo in display_usage()
* relies on tensor type info to switch types
use tensor types of input and output tensors instead of command
line flag from user
* reformatted according to review
3 files changed, 46 insertions, 28 deletions
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 4d2e1ce0bc..d7f49ad875 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -148,14 +148,22 @@ void RunInference(Settings* s) { int wanted_width = dims->data[2]; int wanted_channels = dims->data[3]; - if (s->input_floating) { - downsize<float>(interpreter->typed_tensor<float>(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); - } else { - downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in, - image_height, image_width, image_channels, wanted_height, - wanted_width, wanted_channels, s); + switch (interpreter->tensor(input)->type) { + case kTfLiteFloat32: + s->input_floating = true; + downsize<float>(interpreter->typed_tensor<float>(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); + break; + case kTfLiteUInt8: + downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in, + image_height, image_width, image_channels, + wanted_height, wanted_width, wanted_channels, s); + break; + default: + LOG(FATAL) << "cannot handle input type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); } struct timeval start_time, stop_time; @@ -177,13 +185,22 @@ void RunInference(Settings* s) { std::vector<std::pair<float, int>> top_results; - if (s->input_floating) { - get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size, - num_results, threshold, &top_results, s->input_floating); - } else { - get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), + int output = interpreter->outputs()[0]; + switch (interpreter->tensor(output)->type) { + case kTfLiteFloat32: + get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size, num_results, threshold, &top_results, - s->input_floating); + true); + break; + case kTfLiteUInt8: + get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), + output_size, num_results, threshold, &top_results, + false); + break; + default: + LOG(FATAL) << "cannot handle output type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); } std::vector<string> labels; @@ -203,13 +220,11 @@ void display_usage() { LOG(INFO) << "label_image\n" << "--accelerated, -a: [0|1], use Android NNAPI or note\n" << "--count, -c: loop interpreter->Invoke() for certain times\n" - << "--input_floating, -f: [0|1] type of input layer is floating " - "point numbers\n" << "--input_mean, -b: input mean\n" << "--input_std, -s: input standard deviation\n" << "--image, -i: image_name.bmp\n" << "--labels, -l: labels for the model\n" - << "--tflite_mode, -m: model_name.tflite\n" + << "--tflite_model, -m: model_name.tflite\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -223,7 +238,6 @@ int Main(int argc, char** argv) { static struct option long_options[] = { {"accelerated", required_argument, 0, 'a'}, {"count", required_argument, 0, 'c'}, - {"input_floating", required_argument, 0, 'f'}, {"verbose", required_argument, 0, 'v'}, {"image", required_argument, 0, 'i'}, {"labels", required_argument, 0, 'l'}, @@ -254,11 +268,6 @@ int Main(int argc, char** argv) { s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) optarg, (char**)NULL, 10); break; - case 'f': - s.input_floating = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); - s.input_layer_type = "float"; - break; case 'i': s.input_bmp_name = optarg; break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index ce98e06fc1..4de32e33fb 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H -#include <string> #include "tensorflow/contrib/lite/string.h" +namespace tflite { +namespace label_image { + struct Settings { bool verbose = false; bool accel = false; @@ -33,4 +35,7 @@ struct Settings { int number_of_threads = 4; }; +} // namespace label_image +} // namespace tflite + #endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md index d6019d673f..9ce32cf101 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.md +++ b/tensorflow/contrib/lite/examples/label_image/label_image.md @@ -1,8 +1,12 @@ label_image for TensorFlow Lite inspired by TensorFlow's label_image. + +To build label_image for Android, run $TENSORFLOW_ROOT/configure +and set Android NDK or configure NDK setting in +$TENSORFLOW_ROOT/WORKSPACE first. To build it for android ARMv8: ``` -> bazel build --cxxopt=-std=c++11 \ +> bazel build --config monolithic --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=arm64-v8a \ @@ -10,13 +14,13 @@ To build it for android ARMv8: ``` or ``` -> bazel build --config android_arm64 --cxxopt=-std=c++11 \ +> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \ //tensorflow/contrib/lite/examples/label_image:label_image ``` To build it for android arm-v7a: ``` -> bazel build --cxxopt=-std=c++11 \ +> bazel build --config monolithic --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ --cpu=armeabi-v7a \ @@ -24,7 +28,7 @@ To build it for android arm-v7a: ``` or ``` -> bazel build --config android_arm --cxxopt=-std=c++11 \ +> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \ //tensorflow/contrib/lite/examples/label_image:label_image ``` |