diff options
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/test_util.h')
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/test_util.h | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h index 0eab9e1135..816db41931 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.h +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test { bool Invoke(); + // Sets the (typed) tensor's values at the given index. + template <typename T> + void SetTypedValues(int tensor_index, const std::vector<T>& values) { + memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(), + values.size() * sizeof(T)); + } + + // Returns the (typed) tensor's values at the given index. + template <typename T> + std::vector<T> GetTypedValues(int tensor_index) { + const TfLiteTensor* t = interpreter_->tensor(tensor_index); + const T* tdata = interpreter_->typed_tensor<T>(tensor_index); + return std::vector<T>(tdata, tdata + t->bytes / sizeof(T)); + } + // Sets the tensor's values at the given index. - void SetValues(int tensor_index, const std::vector<float>& values); + void SetValues(int tensor_index, const std::vector<float>& values) { + SetTypedValues<float>(tensor_index, values); + } // Returns the tensor's values at the given index. - std::vector<float> GetValues(int tensor_index); + std::vector<float> GetValues(int tensor_index) { + return GetTypedValues<float>(tensor_index); + } // Sets the tensor's shape at the given index. void SetShape(int tensor_index, const std::vector<int>& values); @@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test { // Returns the tensor's shape at the given index. std::vector<int> GetShape(int tensor_index); + // Returns the tensor's type at the given index. + TfLiteType GetType(int tensor_index); + const TestErrorReporter& error_reporter() const { return error_reporter_; } // Adds `num_tensor` tensors to the model. `inputs` contains the indices of // the input tensors and `outputs` contains the indices of the output // tensors. All tensors are set to have `type` and `dims`. void AddTensors(int num_tensors, const std::vector<int>& inputs, - const std::vector<int>& outputs, const TfLiteType& type, + const std::vector<int>& outputs, TfLiteType type, const std::vector<int>& dims); // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors |