aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-10-03 10:51:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 10:55:31 -0700
commit560624bff65b7b502da2c52f9b250d9181c4a3f7 (patch)
tree29d3aab2396c231223952515333ce2f2c08f8e30 /tensorflow/contrib/lite/python
parentaf1458a9c1a3bc8d49a1e55386950b4941ab1815 (diff)
Internal change.
PiperOrigin-RevId: 215589009
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py17
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc19
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h1
3 files changed, 36 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 5700bf7892..6300552cbe 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -129,6 +129,23 @@ class Interpreter(object):
return details
+ def get_tensor_details(self):
+ """Gets tensor details for every tensor with valid tensor details.
+
+ Tensors where required information about the tensor is not found are not
+ added to the list. This includes temporary tensors without a name.
+
+ Returns:
+ A list of dictionaries containing tensor information.
+ """
+ tensor_details = []
+ for idx in range(self._interpreter.NumTensors()):
+ try:
+ tensor_details.append(self._get_tensor_details(idx))
+ except ValueError:
+ pass
+ return tensor_details
+
def get_input_details(self):
"""Gets model input details.
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 418f19a179..1e2384b6d2 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -277,13 +277,20 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
Py_RETURN_NONE;
}
+int InterpreterWrapper::NumTensors() const {
+ if (!interpreter_) {
+ return 0;
+ }
+ return interpreter_->tensors_size();
+}
+
std::string InterpreterWrapper::TensorName(int i) const {
if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
return "";
}
const TfLiteTensor* tensor = interpreter_->tensor(i);
- return tensor->name;
+ return tensor->name ? tensor->name : "";
}
PyObject* InterpreterWrapper::TensorType(int i) const {
@@ -291,6 +298,11 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->type == kTfLiteNoType) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
+ return nullptr;
+ }
+
int code = TfLiteTypeToPyArrayType(tensor->type);
if (code == -1) {
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
@@ -302,7 +314,12 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
PyObject* InterpreterWrapper::TensorSize(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->dims == nullptr) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
+ return nullptr;
+ }
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index f5ca81e62a..b98046fe8a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -59,6 +59,7 @@ class InterpreterWrapper {
PyObject* OutputIndices() const;
PyObject* ResizeInputTensor(int i, PyObject* value);
+ int NumTensors() const;
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;