diff options
author | 2018-05-04 16:37:27 -0700 | |
---|---|---|
committer | 2018-05-05 08:31:44 -0700 | |
commit | 037e52e20157985d3f385f8e0426cdde3f5aae2b (patch) | |
tree | ae99f730824213a83b4048ddce65a171d5c71f39 /tensorflow/contrib/lite/interpreter.h | |
parent | 008a3b69a601dc68fd940eb8a03b0c445714a339 (diff) |
Expose read-only versions of tensors in tflite.
PiperOrigin-RevId: 195491701
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 37 |
1 files changed, 31 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263..0450e86ae7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template <class T> T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,6 +237,18 @@ class Interpreter { return nullptr; } + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template <class T> + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType<T>()) { + return reinterpret_cast<const T*>(tensor_ptr->data.raw); + } + } + return nullptr; + } + // Return a pointer into the data of a given input tensor. The given index // must be between 0 and inputs().size(). template <class T> @@ -238,13 +256,20 @@ class Interpreter { return typed_tensor<T>(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template <class T> T* typed_output_tensor(int index) { return typed_tensor<T>(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template <class T> + const T* typed_output_tensor(int index) const { + return typed_tensor<T>(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. |