aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tools
diff options
context:
space:
mode:
authorGravatar Yunlu Li <yunluli@google.com>2018-09-11 15:37:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 15:48:56 -0700
commit46bfc4766874bd41adaeb317a11075ce094ca1bb (patch)
tree6cce388c7a28db9f6c89011b64f3b55a1ce4c959 /tensorflow/contrib/lite/tools
parent6305a6d83552ba6a472cd72398b60d9241467f1f (diff)
Regenerate input for every inference.
PiperOrigin-RevId: 212535619
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc63
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h3
2 files changed, 36 insertions, 30 deletions
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 02039922b4..0f3b3b40f8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -232,6 +232,39 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
return total_input_bytes;
}
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+ auto interpreter_inputs = interpreter->inputs();
+ // Set the values of the input tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ std::vector<int> sizes = input.shape;
+
+ // TODO(ahentz): below we ignore the 0-th dimension (number of batches).
+ if (t->type == kTfLiteFloat32) {
+ FillRandomValue<float>(
+ interpreter->typed_tensor<float>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+ } else if (t->type == kTfLiteUInt8) {
+ FillRandomValue<uint8_t>(
+ interpreter->typed_tensor<uint8_t>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<uint8_t>(rand()) % 255; });
+ } else if (t->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, sizes, []() {
+ return "we're have some friends over saturday to hang out in the yard";
+ });
+ buffer.WriteToTensor(interpreter->tensor(i));
+ } else {
+ TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+ << " of type " << t->type;
+ }
+ }
+}
+
void BenchmarkTfLiteModel::Init() {
std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -305,36 +338,6 @@ void BenchmarkTfLiteModel::Init() {
if (interpreter->AllocateTensors() != kTfLiteOk) {
TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
}
-
- // Set the values of the input tensors.
- for (int j = 0; j < inputs.size(); ++j) {
- const InputLayerInfo& input = inputs[j];
- int i = interpreter_inputs[j];
- TfLiteTensor* t = interpreter->tensor(i);
- std::vector<int> sizes = input.shape;
-
- // TODO(ahentz): below we ignore the 0-th dimension (number of batches).
- if (t->type == kTfLiteFloat32) {
- FillRandomValue<float>(
- interpreter->typed_tensor<float>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
- } else if (t->type == kTfLiteUInt8) {
- FillRandomValue<uint8_t>(
- interpreter->typed_tensor<uint8_t>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<uint8_t>(rand()) % 255; });
- } else if (t->type == kTfLiteString) {
- tflite::DynamicBuffer buffer;
- FillRandomString(&buffer, sizes, []() {
- return "we're have some friends over saturday to hang out in the yard";
- });
- buffer.WriteToTensor(interpreter->tensor(i));
- } else {
- TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
- << " of type " << t->type;
- }
- }
}
void BenchmarkTfLiteModel::RunImpl() {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 4c4320a998..8541512bc8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -69,6 +69,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::vector<int> shape;
};
+ protected:
+ void PrepareInputsAndOutputs() override;
+
private:
#ifdef TFLITE_EXTENDED
std::unique_ptr<EagerDelegate> delegate_;