aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2018-07-18 10:25:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 10:29:10 -0700
commitec85fc632651324cb674793ae9741fb9a9a9c4f6 (patch)
tree4d75f4959c7b02946dbcaefa0a56557e94aab40f /tensorflow/contrib/lite/python
parent9cc29a75ce8131db67b48e92dac3c16a255b92ed (diff)
Correct exception handling in TFLite Python interpreter.
We were incorrectly returning nullptr all the time. We need to return None sometimes and check for it. PiperOrigin-RevId: 205098110
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc23
1 files changed, 14 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index c38b692dcd..f97919363b 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -340,6 +340,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
namespace {
+// Checks to see if a tensor access can succeed (returns nullptr on error).
+// Otherwise returns Py_None.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
@@ -362,7 +364,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
return nullptr;
}
- return nullptr;
+ Py_RETURN_NONE;
}
} // namespace
@@ -371,10 +373,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
+
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
// Make a buffer copy but we must tell Numpy It owns that data or else
@@ -396,10 +400,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);