aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/test_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/test_util.h')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.h28
1 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
index 0eab9e1135..816db41931 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test {
bool Invoke();
+ // Sets the (typed) tensor's values at the given index.
+ template <typename T>
+ void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+ memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+ values.size() * sizeof(T));
+ }
+
+ // Returns the (typed) tensor's values at the given index.
+ template <typename T>
+ std::vector<T> GetTypedValues(int tensor_index) {
+ const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+ const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+ return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+ }
+
// Sets the tensor's values at the given index.
- void SetValues(int tensor_index, const std::vector<float>& values);
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ SetTypedValues<float>(tensor_index, values);
+ }
// Returns the tensor's values at the given index.
- std::vector<float> GetValues(int tensor_index);
+ std::vector<float> GetValues(int tensor_index) {
+ return GetTypedValues<float>(tensor_index);
+ }
// Sets the tensor's shape at the given index.
void SetShape(int tensor_index, const std::vector<int>& values);
@@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test {
// Returns the tensor's shape at the given index.
std::vector<int> GetShape(int tensor_index);
+ // Returns the tensor's type at the given index.
+ TfLiteType GetType(int tensor_index);
+
const TestErrorReporter& error_reporter() const { return error_reporter_; }
// Adds `num_tensor` tensors to the model. `inputs` contains the indices of
// the input tensors and `outputs` contains the indices of the output
// tensors. All tensors are set to have `type` and `dims`.
void AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs, const TfLiteType& type,
+ const std::vector<int>& outputs, TfLiteType type,
const std::vector<int>& dims);
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors