aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/flex/test_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/flex/test_util.h')
-rw-r--r--tensorflow/contrib/lite/delegates/flex/test_util.h119
1 files changed, 119 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/flex/test_util.h b/tensorflow/contrib/lite/delegates/flex/test_util.h
new file mode 100644
index 0000000000..a8c81b90a3
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/flex/test_util.h
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace flex {
+namespace testing {
+
+enum TfOpType {
+ kUnpack,
+ kIdentity,
+ kAdd,
+ kMul,
+ // Represents an op that does not exist in TensorFlow.
+ kNonExistent,
+ // Represents an valid TensorFlow op where the NodeDef is incompatible.
+ kIncompatibleNodeDef,
+};
+
+// This class creates models with TF and TFLite ops. In order to use this class
+// to test the Flex delegate, implement a function that calls
+// interpreter->ModifyGraphWithDelegate.
+class FlexModelTest : public ::testing::Test {
+ public:
+ FlexModelTest() {}
+ ~FlexModelTest() {}
+
+ bool Invoke();
+
+ // Sets the (typed) tensor's values at the given index.
+ template <typename T>
+ void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+ memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+ values.size() * sizeof(T));
+ }
+
+ // Returns the (typed) tensor's values at the given index.
+ template <typename T>
+ std::vector<T> GetTypedValues(int tensor_index) {
+ const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+ const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+ return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+ }
+
+ // Sets the tensor's values at the given index.
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ SetTypedValues<float>(tensor_index, values);
+ }
+
+ // Returns the tensor's values at the given index.
+ std::vector<float> GetValues(int tensor_index) {
+ return GetTypedValues<float>(tensor_index);
+ }
+
+ // Sets the tensor's shape at the given index.
+ void SetShape(int tensor_index, const std::vector<int>& values);
+
+ // Returns the tensor's shape at the given index.
+ std::vector<int> GetShape(int tensor_index);
+
+ // Returns the tensor's type at the given index.
+ TfLiteType GetType(int tensor_index);
+
+ const TestErrorReporter& error_reporter() const { return error_reporter_; }
+
+ // Adds `num_tensor` tensors to the model. `inputs` contains the indices of
+ // the input tensors and `outputs` contains the indices of the output
+ // tensors. All tensors are set to have `type` and `dims`.
+ void AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs, TfLiteType type,
+ const std::vector<int>& dims);
+
+ // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
+ // and `outputs` contains the indices of the output tensors.
+ void AddTfLiteMulOp(const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ // Adds a TensorFlow op. `inputs` contains the indices of the
+ // input tensors and `outputs` contains the indices of the output tensors.
+ // This function is limited to the set of ops defined in TfOpType.
+ void AddTfOp(TfOpType op, const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ protected:
+ std::unique_ptr<Interpreter> interpreter_;
+ TestErrorReporter error_reporter_;
+
+ private:
+ // Helper method to add a TensorFlow op. tflite_names needs to start with
+ // "Flex" in order to work with the Flex delegate.
+ void AddTfOp(const char* tflite_name, const string& tf_name,
+ const string& nodedef_str, const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ std::vector<std::vector<uint8_t>> flexbuffers_;
+};
+
+} // namespace testing
+} // namespace flex
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_