diff options
Diffstat (limited to 'tensorflow/contrib/lite/examples/label_image/label_image.cc')
-rw-r--r-- | tensorflow/contrib/lite/examples/label_image/label_image.cc | 300 |
1 file changed, 300 insertions, 0 deletions
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/optional_debug_tools.h"
#include "tensorflow/contrib/lite/string_util.h"

#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h"

// Minimal logging shim: the severity argument (INFO/ERROR/FATAL) is ignored
// and every message is streamed to stderr. LOG(FATAL) does NOT abort here;
// callers that must stop call exit() explicitly after logging.
#define LOG(x) std::cerr

namespace tflite {
namespace label_image {

// Converts a struct timeval (seconds + microseconds) to a single value in
// microseconds, returned as double so callers can subtract two timestamps
// and divide without integer truncation.
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
+// Takes a file name, and loads a list of labels from it, one per line, and +// returns a vector of the strings. It pads with empty strings so the length +// of the result is a multiple of 16, because our model expects that. +TfLiteStatus ReadLabelsFile(const string& file_name, + std::vector<string>* result, + size_t* found_label_count) { + std::ifstream file(file_name); + if (!file) { + LOG(FATAL) << "Labels file " << file_name << " not found\n"; + return kTfLiteError; + } + result->clear(); + string line; + while (std::getline(file, line)) { + result->push_back(line); + } + *found_label_count = result->size(); + const int padding = 16; + while (result->size() % padding) { + result->emplace_back(); + } + return kTfLiteOk; +} + +void RunInference(Settings* s) { + if (!s->model_name.c_str()) { + LOG(ERROR) << "no model file name\n"; + exit(-1); + } + + std::unique_ptr<tflite::FlatBufferModel> model; + std::unique_ptr<tflite::Interpreter> interpreter; + model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str()); + if (!model) { + LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n"; + exit(-1); + } + LOG(INFO) << "Loaded model " << s->model_name << "\n"; + model->error_reporter(); + LOG(INFO) << "resolved reporter\n"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter\n"; + exit(-1); + } + + interpreter->UseNNAPI(s->accel); + + if (s->verbose) { + LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; + LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n"; + LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n"; + LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n"; + + int t_size = interpreter->tensors_size(); + for (int i = 0; i < t_size; i++) { + if (interpreter->tensor(i)->name) + LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " + << 
interpreter->tensor(i)->bytes << ", " + << interpreter->tensor(i)->type << ", " + << interpreter->tensor(i)->params.scale << ", " + << interpreter->tensor(i)->params.zero_point << "\n"; + } + } + + if (s->number_of_threads != -1) { + interpreter->SetNumThreads(s->number_of_threads); + } + + int image_width = 224; + int image_height = 224; + int image_channels = 3; + uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, + &image_channels, s); + + int input = interpreter->inputs()[0]; + if (s->verbose) LOG(INFO) << "input: " << input << "\n"; + + const std::vector<int> inputs = interpreter->inputs(); + const std::vector<int> outputs = interpreter->outputs(); + + if (s->verbose) { + LOG(INFO) << "number of inputs: " << inputs.size() << "\n"; + LOG(INFO) << "number of outputs: " << outputs.size() << "\n"; + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + if (s->verbose) PrintInterpreterState(interpreter.get()); + + // get input dimension from the input tensor metadata + // assuming one input only + TfLiteIntArray* dims = interpreter->tensor(input)->dims; + int wanted_height = dims->data[1]; + int wanted_width = dims->data[2]; + int wanted_channels = dims->data[3]; + + if (s->input_floating) { + downsize<float>(interpreter->typed_tensor<float>(input), in, image_height, + image_width, image_channels, wanted_height, wanted_width, + wanted_channels, s); + } else { + downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in, + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); + } + + struct timeval start_time, stop_time; + gettimeofday(&start_time, NULL); + for (int i = 0; i < s->loop_count; i++) { + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke tflite!\n"; + } + } + gettimeofday(&stop_time, NULL); + LOG(INFO) << "invoked \n"; + LOG(INFO) << "average time: " + << (get_us(stop_time) - get_us(start_time)) / 
(s->loop_count * 1000) + << " ms \n"; + + const int output_size = 1000; + const size_t num_results = 5; + const float threshold = 0.001f; + + std::vector<std::pair<float, int>> top_results; + + if (s->input_floating) { + get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size, + num_results, threshold, &top_results, s->input_floating); + } else { + get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), + output_size, num_results, threshold, &top_results, + s->input_floating); + } + + std::vector<string> labels; + size_t label_count; + + if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk) + exit(-1); + + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n"; + } +} + +void display_usage() { + LOG(INFO) << "label_image\n" + << "--accelerated, -a: [0|1], use Android NNAPI or note\n" + << "--count, -c: loop interpreter->Invoke() for certain times\n" + << "--input_floating, -f: [0|1] type of input layer is floating " + "point numbers\n" + << "--input_mean, -b: input mean\n" + << "--input_std, -s: input standard deviation\n" + << "--image, -i: image_name.bmp\n" + << "--labels, -l: labels for the model\n" + << "--tflite_mode, -m: model_name.tflite\n" + << "--threads, -t: number of threads\n" + << "--verbose, -v: [0|1] print more information\n" + << "\n"; +} + +int Main(int argc, char** argv) { + Settings s; + + int c; + while (1) { + static struct option long_options[] = { + {"accelerated", required_argument, 0, 'a'}, + {"count", required_argument, 0, 'c'}, + {"input_floating", required_argument, 0, 'f'}, + {"verbose", required_argument, 0, 'v'}, + {"image", required_argument, 0, 'i'}, + {"labels", required_argument, 0, 'l'}, + {"tflite_model", required_argument, 0, 'm'}, + {"threads", required_argument, 0, 't'}, + {"input_mean", required_argument, 0, 'b'}, + {"input_std", 
required_argument, 0, 's'}, + {0, 0, 0, 0}}; + + /* getopt_long stores the option index here. */ + int option_index = 0; + + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, + &option_index); + + /* Detect the end of the options. */ + if (c == -1) break; + + switch (c) { + case 'a': + s.accel = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'b': + s.input_mean = strtod(optarg, NULL); + break; + case 'c': + s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'f': + s.input_floating = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + s.input_layer_type = "float"; + break; + case 'i': + s.input_bmp_name = optarg; + break; + case 'l': + s.labels_file_name = optarg; + break; + case 'm': + s.model_name = optarg; + break; + case 's': + s.input_std = strtod(optarg, NULL); + break; + case 't': + s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'v': + s.verbose = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'h': + case '?': + /* getopt_long already printed an error message. */ + display_usage(); + exit(-1); + default: + exit(-1); + } + } + RunInference(&s); + return 0; +} + +} // namespace label_image +} // namespace tflite + +int main(int argc, char** argv) { + return tflite::label_image::Main(argc, argv); +} |