aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples
diff options
context:
space:
mode:
authorGravatar Koan-Sin Tan <koansin.tan@gmail.com>2018-07-20 14:21:42 +0800
committerGravatar Koan-Sin Tan <koansin.tan@gmail.com>2018-07-20 14:21:42 +0800
commita7fc3811dfbb75f14a9ece2d8904b72b8a45a670 (patch)
tree184775c23b880f86a876cb2aab56f666d10d0823 /tensorflow/contrib/lite/examples
parentdb308efbf4e95a7362fde90d35447091349b548e (diff)
get output size from output tensor
1. get output size from the output tensor 2. add command line option to specify number of results
Diffstat (limited to 'tensorflow/contrib/lite/examples')
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc19
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h1
2 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 86d7d1cc4a..7c6f523041 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -213,22 +213,23 @@ void RunInference(Settings* s) {
}
}
- const int output_size = 1000;
- const size_t num_results = 5;
const float threshold = 0.001f;
std::vector<std::pair<float, int>> top_results;
int output = interpreter->outputs()[0];
+ TfLiteIntArray* output_dims = interpreter->tensor(output)->dims;
+ // assume output dims to be something like (1, 1, ... ,size)
+ auto output_size = output_dims->data[output_dims->size - 1];
switch (interpreter->tensor(output)->type) {
case kTfLiteFloat32:
get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
- num_results, threshold, &top_results, true);
+ s->number_of_results, threshold, &top_results, true);
break;
case kTfLiteUInt8:
get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
- output_size, num_results, threshold, &top_results,
- false);
+ output_size, s->number_of_results, threshold,
+ &top_results, false);
break;
default:
LOG(FATAL) << "cannot handle output type "
@@ -259,6 +260,7 @@ void display_usage() {
<< "--labels, -l: labels for the model\n"
<< "--tflite_model, -m: model_name.tflite\n"
<< "--profiling, -p: [0|1], profiling or not\n"
+ << "--num_results, -r: number of results to show\n"
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "\n";
@@ -280,12 +282,13 @@ int Main(int argc, char** argv) {
{"threads", required_argument, nullptr, 't'},
{"input_mean", required_argument, nullptr, 'b'},
{"input_std", required_argument, nullptr, 's'},
+ {"num_results", required_argument, nullptr, 'r'},
{nullptr, 0, nullptr, 0}};
/* getopt_long stores the option index here. */
int option_index = 0;
- c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options,
+ c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:r:s:t:v:", long_options,
&option_index);
/* Detect the end of the options. */
@@ -315,6 +318,10 @@ int Main(int argc, char** argv) {
s.profiling =
strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
+ case 'r':
+ s.number_of_results =
+ strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
+ break;
case 's':
s.input_std = strtod(optarg, nullptr);
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index 4b48014e1c..34c223f713 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -34,6 +34,7 @@ struct Settings {
string labels_file_name = "./labels.txt";
string input_layer_type = "uint8_t";
int number_of_threads = 4;
+ int number_of_results = 5;
};
} // namespace label_image