Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_driver.cc')
-rw-r--r--  tensorflow/contrib/lite/testing/tflite_driver.cc  208
1 file changed, 208 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
new file mode 100644
index 0000000000..cf9df2ec26
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <cmath>    // std::abs
+#include <cstdint>  // uint8_t
+#include <cstring>  // memset
+#include <iostream>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+namespace {
+
+// Returns the value in the given position in a tensor.
+template <typename T>
+T Value(const TfLitePtrUnion& data, int index);
+template <>
+float Value(const TfLitePtrUnion& data, int index) {
+ return data.f[index];
+}
+template <>
+uint8_t Value(const TfLitePtrUnion& data, int index) {
+ return data.uint8[index];
+}
+
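+// Copies a flat vector of values into a tensor's raw data buffer. The caller
+// must ensure the buffer holds at least values.size() * sizeof(T) bytes.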
+template <typename T>
+void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
+ T* input_ptr = reinterpret_cast<T*>(data->raw);
+ for (const T& v : values) {
+ *input_ptr = v;
+ ++input_ptr;
+ }
+}
+
+} // namespace
+
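+// Holds the expected contents of a single output tensor and compares them
+// to the values actually produced by an invocation.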
+class TfLiteDriver::Expectation {
+ public:
+ Expectation() { data_.raw = nullptr; }
+ ~Expectation() { delete[] data_.raw; }
+ template <typename T>
+ void SetData(const string& csv_values) {
+ const auto& values = testing::Split<T>(csv_values, ",");
+ data_.raw = new char[values.size() * sizeof(T)];
+ SetTensorData(values, &data_);
+ }
+
+ bool Check(bool verbose, const TfLiteTensor& tensor) {
+ switch (tensor.type) {
+ case kTfLiteFloat32:
+ return TypedCheck<float>(verbose, tensor);
+ case kTfLiteUInt8:
+ return TypedCheck<uint8_t>(verbose, tensor);
+ default:
+ return false;
+ }
+ }
+
+ private:
+ template <typename T>
+ bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
+ int tensor_size = tensor.bytes / sizeof(T);
+
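+    // Compare element-wise against the expected data, using a fixed absolute
+    // tolerance of 1e-5 for both float32 and uint8 tensors.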
+ bool good_output = true;
+ for (int i = 0; i < tensor_size; ++i) {
+ if (std::abs(Value<T>(data_, i) - Value<T>(tensor.data, i)) > 1e-5) {
+ good_output = false;
+ if (verbose) {
+ std::cerr << " index " << i << ": " << Value<T>(data_, i)
+ << " != " << Value<T>(tensor.data, i) << std::endl;
+ }
+ }
+ }
+ return good_output;
+ }
+
+ TfLitePtrUnion data_;
+};
+
+TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::~TfLiteDriver() {}
+
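+// (Re)allocates tensor buffers if a model load or a reshape marked them dirty.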
+void TfLiteDriver::AllocateTensors() {
+ if (must_allocate_tensors_) {
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ std::cerr << "Failed to allocate tensors" << std::endl;
+ abort();
+ }
+ must_allocate_tensors_ = false;
+ }
+}
+
+void TfLiteDriver::LoadModel(const string& bin_file_path) {
+ if (!IsValid()) return;
+ std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
+
+ model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
+ if (!model_) {
+ Invalidate("Failed to mmap model " + bin_file_path);
+ return;
+ }
+ ops::builtin::BuiltinOpResolver builtins;
+ InterpreterBuilder(*model_, builtins)(&interpreter_);
+ if (!interpreter_) {
+ Invalidate("Failed build interpreter");
+ return;
+ }
+
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::ResetTensor(int id) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ memset(tensor->data.raw, 0, tensor->bytes);
+}
+
+void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ if (interpreter_->ResizeInputTensor(
+ id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
+ Invalidate("Failed to resize input tensor " + std::to_string(id));
+ return;
+ }
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::SetInput(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ switch (tensor->type) {
+ case kTfLiteFloat32: {
+ const auto& values = testing::Split<float>(csv_values, ",");
+ if (!CheckSizes<float>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ case kTfLiteUInt8: {
+ const auto& values = testing::Split<uint8_t>(csv_values, ",");
+ if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ expected_output_[id].reset(new Expectation);
+ switch (tensor->type) {
+ case kTfLiteFloat32:
+ expected_output_[id]->SetData<float>(csv_values);
+ break;
+ case kTfLiteUInt8:
+ expected_output_[id]->SetData<uint8_t>(csv_values);
+ break;
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::Invoke() {
+ if (!IsValid()) return;
+ if (interpreter_->Invoke() != kTfLiteOk) {
+ Invalidate("Failed to invoke interpreter");
+ }
+}
+
+bool TfLiteDriver::CheckResults() {
+ if (!IsValid()) return false;
+ bool success = true;
+ for (const auto& p : expected_output_) {
+ int id = p.first;
+ auto* tensor = interpreter_->tensor(id);
+ if (!p.second->Check(/*verbose=*/false, *tensor)) {
+      // Do not invalidate anything here. Instead, simply output the
+      // differences and return false. Invalidating would prevent all
+      // subsequent invocations from running.
+ std::cerr << "There were errors in invocation '" << GetInvocationId()
+ << "', output tensor '" << id << "':" << std::endl;
+ p.second->Check(/*verbose=*/true, *tensor);
+ success = false;
+ SetOverallSuccess(false);
+ }
+ }
+ expected_output_.clear();
+ return success;
+}
+
+} // namespace testing
+} // namespace tflite
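
A minimal usage sketch of the driver above, for reference. The call sequence
mirrors the methods defined in this file; the model path "add.bin" and the
tensor ids 0 (input) and 1 (output) are hypothetical, and helpers such as
GetFullPath() and IsValid() come from the TestRunner base class, which is not
shown here.

  #include "tensorflow/contrib/lite/testing/tflite_driver.h"

  int main() {
    tflite::testing::TfLiteDriver driver(/*use_nnapi=*/false);
    driver.LoadModel("add.bin");              // resolved via GetFullPath()
    driver.AllocateTensors();
    driver.SetInput(0, "1.0,2.0,3.0");        // CSV values, parsed by Split<T>
    driver.SetExpectation(1, "2.0,4.0,6.0");  // expected output, also CSV
    driver.Invoke();
    return driver.CheckResults() ? 0 : 1;     // prints diffs on mismatch
  }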