aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/examples/label_image
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 14:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:33:20 -0700
commit325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch)
treed41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/contrib/lite/examples/label_image
parent46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff)
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/contrib/lite/examples/label_image')
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc45
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h1
2 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index a91467d345..456c5c6dc7 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <fstream>
+#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
@@ -70,6 +71,23 @@ TfLiteStatus ReadLabelsFile(const string& file_name,
return kTfLiteOk;
}
+void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index,
+ TfLiteRegistration registration) {
+ // output something like
+ // time (ms) , Node xxx, OpCode xxx, symblic name
+ // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D
+
+
+ LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
+ << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
+ << ", Node " << std::setw(3) << std::setprecision(3) << op_index
+ << ", OpCode " << std::setw(3) << std::setprecision(3)
+ << registration.builtin_code << ", "
+ << EnumNameBuiltinOperator(
+ (BuiltinOperator)registration.builtin_code)
+ << "\n";
+}
+
void RunInference(Settings* s) {
if (!s->model_name.c_str()) {
LOG(ERROR) << "no model file name\n";
@@ -166,6 +184,11 @@ void RunInference(Settings* s) {
exit(-1);
}
+ profiling::Profiler* profiler = new profiling::Profiler();
+ interpreter->SetProfiler(profiler);
+
+ if (s->profiling) profiler->StartProfiling();
+
struct timeval start_time, stop_time;
gettimeofday(&start_time, NULL);
for (int i = 0; i < s->loop_count; i++) {
@@ -179,6 +202,18 @@ void RunInference(Settings* s) {
<< (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
<< " ms \n";
+ if (s->profiling) {
+ profiler->StopProfiling();
+ auto profile_events = profiler->GetProfileEvents();
+ for (int i = 0; i < profile_events.size(); i++) {
+ auto op_index = profile_events[i]->event_metadata;
+ const auto node_and_registration =
+ interpreter->node_and_registration(op_index);
+ const TfLiteRegistration registration = node_and_registration->second;
+ PrintProfilingInfo(profile_events[i], op_index, registration);
+ }
+ }
+
const int output_size = 1000;
const size_t num_results = 5;
const float threshold = 0.001f;
@@ -217,13 +252,14 @@ void RunInference(Settings* s) {
void display_usage() {
LOG(INFO) << "label_image\n"
- << "--accelerated, -a: [0|1], use Android NNAPI or note\n"
+ << "--accelerated, -a: [0|1], use Android NNAPI or not\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
<< "--input_mean, -b: input mean\n"
<< "--input_std, -s: input standard deviation\n"
<< "--image, -i: image_name.bmp\n"
<< "--labels, -l: labels for the model\n"
<< "--tflite_model, -m: model_name.tflite\n"
+ << "--profiling, -p: [0|1], profiling or not\n"
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "\n";
@@ -241,6 +277,7 @@ int Main(int argc, char** argv) {
{"image", required_argument, 0, 'i'},
{"labels", required_argument, 0, 'l'},
{"tflite_model", required_argument, 0, 'm'},
+ {"profiling", required_argument, 0, 'p'},
{"threads", required_argument, 0, 't'},
{"input_mean", required_argument, 0, 'b'},
{"input_std", required_argument, 0, 's'},
@@ -249,7 +286,7 @@ int Main(int argc, char** argv) {
/* getopt_long stores the option index here. */
int option_index = 0;
- c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options,
+ c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options,
&option_index);
/* Detect the end of the options. */
@@ -276,6 +313,10 @@ int Main(int argc, char** argv) {
case 'm':
s.model_name = optarg;
break;
+ case 'p':
+ s.profiling = strtol( // NOLINT(runtime/deprecated_fn)
+ optarg, (char**)NULL, 10);
+ break;
case 's':
s.input_std = strtod(optarg, NULL);
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index 4de32e33fb..4b48014e1c 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -25,6 +25,7 @@ struct Settings {
bool verbose = false;
bool accel = false;
bool input_floating = false;
+ bool profiling = false;
int loop_count = 1;
float input_mean = 127.5f;
float input_std = 127.5f;