diff options
author | 2018-09-11 15:37:41 -0700 | |
---|---|---|
committer | 2018-09-11 15:48:56 -0700 | |
commit | 46bfc4766874bd41adaeb317a11075ce094ca1bb (patch) | |
tree | 6cce388c7a28db9f6c89011b64f3b55a1ce4c959 /tensorflow/contrib/lite/tools | |
parent | 6305a6d83552ba6a472cd72398b60d9241467f1f (diff) |
Regenerate input for every inference.
PiperOrigin-RevId: 212535619
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc | 63 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h | 3 |
2 files changed, 36 insertions, 30 deletions
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 02039922b4..0f3b3b40f8 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -232,6 +232,39 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { return total_input_bytes; } +void BenchmarkTfLiteModel::PrepareInputsAndOutputs() { + auto interpreter_inputs = interpreter->inputs(); + // Set the values of the input tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + std::vector<int> sizes = input.shape; + + // TODO(ahentz): below we ignore the O-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue<float>( + interpreter->typed_tensor<float>(i), + std::vector<int>(sizes.begin() + 1, sizes.end()), + []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue<uint8_t>( + interpreter->typed_tensor<uint8_t>(i), + std::vector<int>(sizes.begin() + 1, sizes.end()), + []() { return static_cast<uint8_t>(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're have some friends over saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name + << " of type " << t->type; + } + } +} + void BenchmarkTfLiteModel::Init() { std::string graph = params_.Get<std::string>("graph"); model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); @@ -305,36 +338,6 @@ void BenchmarkTfLiteModel::Init() { if (interpreter->AllocateTensors() != kTfLiteOk) { TFLITE_LOG(FATAL) << "Failed to allocate tensors!"; } - - // Set the values of the input tensors. - for (int j = 0; j < inputs.size(); ++j) { - const InputLayerInfo& input = inputs[j]; - int i = interpreter_inputs[j]; - TfLiteTensor* t = interpreter->tensor(i); - std::vector<int> sizes = input.shape; - - // TODO(ahentz): below we ignore the O-th dimension (number of batches). - if (t->type == kTfLiteFloat32) { - FillRandomValue<float>( - interpreter->typed_tensor<float>(i), - std::vector<int>(sizes.begin() + 1, sizes.end()), - []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; }); - } else if (t->type == kTfLiteUInt8) { - FillRandomValue<uint8_t>( - interpreter->typed_tensor<uint8_t>(i), - std::vector<int>(sizes.begin() + 1, sizes.end()), - []() { return static_cast<uint8_t>(rand()) % 255; }); - } else if (t->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, sizes, []() { - return "we're have some friends over saturday to hang out in the yard"; - }); - buffer.WriteToTensor(interpreter->tensor(i)); - } else { - TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name - << " of type " << t->type; - } - } } void BenchmarkTfLiteModel::RunImpl() { diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index 4c4320a998..8541512bc8 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -69,6 +69,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel { std::vector<int> shape; }; + protected: + void PrepareInputsAndOutputs() override; + private: #ifdef TFLITE_EXTENDED std::unique_ptr<EagerDelegate> delegate_; |