diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-10-03 10:51:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 10:55:31 -0700 |
commit | 560624bff65b7b502da2c52f9b250d9181c4a3f7 (patch) | |
tree | 29d3aab2396c231223952515333ce2f2c08f8e30 /tensorflow/contrib/lite/python | |
parent | af1458a9c1a3bc8d49a1e55386950b4941ab1815 (diff) |
Internal change.
PiperOrigin-RevId: 215589009
Diffstat (limited to 'tensorflow/contrib/lite/python')
3 files changed, 36 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 5700bf7892..6300552cbe 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -129,6 +129,23 @@ class Interpreter(object): return details + def get_tensor_details(self): + """Gets tensor details for every tensor with valid tensor details. + + Tensors where required information about the tensor is not found are not + added to the list. This includes temporary tensors without a name. + + Returns: + A list of dictionaries containing tensor information. + """ + tensor_details = [] + for idx in range(self._interpreter.NumTensors()): + try: + tensor_details.append(self._get_tensor_details(idx)) + except ValueError: + pass + return tensor_details + def get_input_details(self): """Gets model input details. diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 418f19a179..1e2384b6d2 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -277,13 +277,20 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { Py_RETURN_NONE; } +int InterpreterWrapper::NumTensors() const { + if (!interpreter_) { + return 0; + } + return interpreter_->tensors_size(); +} + std::string InterpreterWrapper::TensorName(int i) const { if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { return ""; } const TfLiteTensor* tensor = interpreter_->tensor(i); - return tensor->name; + return tensor->name ? tensor->name : ""; } PyObject* InterpreterWrapper::TensorType(int i) const { @@ -291,6 +298,11 @@ PyObject* InterpreterWrapper::TensorType(int i) const { TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); + if (tensor->type == kTfLiteNoType) { + PyErr_Format(PyExc_ValueError, "Tensor with no type found."); + return nullptr; + } + int code = TfLiteTypeToPyArrayType(tensor->type); if (code == -1) { PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); @@ -302,7 +314,12 @@ PyObject* InterpreterWrapper::TensorType(int i) const { PyObject* InterpreterWrapper::TensorSize(int i) const { TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_TENSOR_BOUNDS_CHECK(i); + const TfLiteTensor* tensor = interpreter_->tensor(i); + if (tensor->dims == nullptr) { + PyErr_Format(PyExc_ValueError, "Tensor with no shape found."); + return nullptr; + } PyObject* np_array = PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index f5ca81e62a..b98046fe8a 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -59,6 +59,7 @@ class InterpreterWrapper { PyObject* OutputIndices() const; PyObject* ResizeInputTensor(int i, PyObject* value); + int NumTensors() const; std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; |