diff options
Diffstat (limited to 'tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc')
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index 08648bcfe2..f86c0445b0 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -84,7 +84,7 @@ std::vector<Flag> BenchmarkModel::GetFlags() { }; } -void BenchmarkModel::LogFlags() { +void BenchmarkModel::LogParams() { TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]"; TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.Get<float>("run_delay") << "]"; @@ -98,10 +98,13 @@ void BenchmarkModel::LogFlags() { << "]"; } +void BenchmarkModel::PrepareInputsAndOutputs() {} + Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) { Stat<int64_t> run_stats; TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; for (int run = 0; run < num_times; run++) { + PrepareInputsAndOutputs(); listeners_.OnSingleRunStart(run_type); int64_t start_us = profiling::time::NowMicros(); RunImpl(); @@ -119,12 +122,18 @@ Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) { return run_stats; } +bool BenchmarkModel::ValidateParams() { return true; } + void BenchmarkModel::Run(int argc, char **argv) { if (!ParseFlags(argc, argv)) { return; } + Run(); +} - LogFlags(); +void BenchmarkModel::Run() { + ValidateParams(); + LogParams(); listeners_.OnBenchmarkStart(params_); int64_t initialization_start_us = profiling::time::NowMicros(); @@ -152,7 +161,7 @@ bool BenchmarkModel::ParseFlags(int argc, char **argv) { TFLITE_LOG(ERROR) << usage; return false; } - return ValidateFlags(); + return true; } } // namespace benchmark |