aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-04 16:37:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-05 08:31:44 -0700
commit037e52e20157985d3f385f8e0426cdde3f5aae2b (patch)
treeae99f730824213a83b4048ddce65a171d5c71f39 /tensorflow/contrib/lite/interpreter.h
parent008a3b69a601dc68fd940eb8a03b0c445714a339 (diff)
Expose read-only versions of tensors in tflite.
PiperOrigin-RevId: 195491701
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r--tensorflow/contrib/lite/interpreter.h37
1 files changed, 31 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 1074f64263..0450e86ae7 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -201,7 +201,7 @@ class Interpreter {
// Overrides execution plan. This bounds checks indices sent in.
TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
- // Get a tensor data structure.
+ // Get a mutable tensor data structure.
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
TfLiteTensor* tensor(int tensor_index) {
@@ -210,9 +210,14 @@ class Interpreter {
return &context_.tensors[tensor_index];
}
+ // Get an immutable tensor data structure.
+ const TfLiteTensor* tensor(int tensor_index) const {
+ if (tensor_index >= context_.tensors_size || tensor_index < 0)
+ return nullptr;
+ return &context_.tensors[tensor_index];
+ }
+
// Get a pointer to an operation and registration data structure if in bounds.
- // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
- // read/write access to structure
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
int node_index) const {
if (node_index >= nodes_and_registration_.size() || node_index < 0)
@@ -220,7 +225,8 @@ class Interpreter {
return &nodes_and_registration_[node_index];
}
- // Perform a checked cast to the appropriate tensor type.
+ // Perform a checked cast to the appropriate tensor type (mutable pointer
+ // version).
template <class T>
T* typed_tensor(int tensor_index) {
if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
@@ -231,6 +237,18 @@ class Interpreter {
return nullptr;
}
+ // Perform a checked cast to the appropriate tensor type (immutable pointer
+ // version).
+ template <class T>
+ const T* typed_tensor(int tensor_index) const {
+ if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
+ if (tensor_ptr->type == typeToTfLiteType<T>()) {
+ return reinterpret_cast<const T*>(tensor_ptr->data.raw);
+ }
+ }
+ return nullptr;
+ }
+
// Return a pointer into the data of a given input tensor. The given index
// must be between 0 and inputs().size().
template <class T>
@@ -238,13 +256,20 @@ class Interpreter {
return typed_tensor<T>(inputs_[index]);
}
- // Return a pointer into the data of a given output tensor. The given index
- // must be between 0 and outputs().size().
+ // Return a mutable pointer into the data of a given output tensor. The given
+ // index must be between 0 and outputs().size().
template <class T>
T* typed_output_tensor(int index) {
return typed_tensor<T>(outputs_[index]);
}
+ // Return an immutable pointer into the data of a given output tensor. The
+ // given index must be between 0 and outputs().size().
+ template <class T>
+ const T* typed_output_tensor(int index) const {
+ return typed_tensor<T>(outputs_[index]);
+ }
+
// Change the dimensionality of a given tensor. Note, this is only acceptable
// for tensor indices that are inputs.
// Returns status of failure or success.