diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/test_runner.h')
-rw-r--r-- | tensorflow/contrib/lite/testing/test_runner.h | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h new file mode 100644 index 0000000000..04ee4d9f7d --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ + +#include <memory> +#include <string> +#include <vector> +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// This is the base class for processing test data. Each one of the virtual +// methods must be implemented to forward the data to the appropriate executor +// (e.g. TF Lite's interpreter, or the NNAPI). +class TestRunner { + public: + TestRunner() {} + virtual ~TestRunner() {} + + // Load the given model, as a path relative to SetModelBaseDir(). + virtual void LoadModel(const string& bin_file_path) = 0; + + // Return the list of input tensors in the loaded model. + virtual const std::vector<int>& GetInputs() = 0; + + // Return the list of output tensors in the loaded model. + virtual const std::vector<int>& GetOutputs() = 0; + + // Prepare for a run by resize the given tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void ReshapeTensor(int id, const string& csv_values) = 0; + + // Reserve memory for all tensors. + virtual void AllocateTensors() = 0; + + // Set the given tensor to some initial state, usually zero. This is + // used to reset persistent buffers in a model. + virtual void ResetTensor(int id) = 0; + + // Define the contents of the given input tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void SetInput(int id, const string& csv_values) = 0; + + // Define what should be expected for an output tensor after Invoke() runs. + // The given 'id' is guaranteed to be one of the ids returned by + // GetOutputs(). + virtual void SetExpectation(int id, const string& csv_values) = 0; + + // Run the model. + virtual void Invoke() = 0; + + // Verify that the contents of all ouputs conform to the existing + // expectations. Return true if there are no expectations or they are all + // satisfied. + virtual bool CheckResults() = 0; + + // Set the base path for loading models. + void SetModelBaseDir(const string& path) { + model_base_dir_ = path; + if (path[path.length() - 1] != '/') { + model_base_dir_ += "/"; + } + } + + // Return the full path of a model. + string GetFullPath(const string& path) { return model_base_dir_ + path; } + + // Give an id to the next invocation to make error reporting more meaningful. + void SetInvocationId(const string& id) { invocation_id_ = id; } + const string& GetInvocationId() const { return invocation_id_; } + + // Invalidate the test runner, preventing it from executing any further. + void Invalidate(const string& error_message) { + error_message_ = error_message; + } + bool IsValid() const { return error_message_.empty(); } + const string& GetErrorMessage() const { return error_message_; } + + // Handle the overall success of this test runner. This will be true if all + // invocations were successful. + void SetOverallSuccess(bool value) { overall_success_ = value; } + bool GetOverallSuccess() const { return overall_success_; } + + protected: + // A helper to check of the given number of values is consistent with the + // number of bytes in a tensor of type T. When incompatibles sizes are found, + // the test runner is invalidated and false is returned. + template <typename T> + bool CheckSizes(size_t tensor_bytes, size_t num_values) { + size_t num_tensor_elements = tensor_bytes / sizeof(T); + if (num_tensor_elements != num_values) { + Invalidate("Expected '" + std::to_string(num_tensor_elements) + + "' elements for a tensor, but only got '" + + std::to_string(num_values) + "'"); + return false; + } + return true; + } + + private: + string model_base_dir_; + string invocation_id_; + bool overall_success_ = true; + + string error_message_; +}; + +} // namespace testing +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ |