diff options
author | Andrew Selle <aselle@google.com> | 2018-07-18 10:25:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 10:29:10 -0700 |
commit | ec85fc632651324cb674793ae9741fb9a9a9c4f6 (patch) | |
tree | 4d75f4959c7b02946dbcaefa0a56557e94aab40f /tensorflow/contrib/lite/python | |
parent | 9cc29a75ce8131db67b48e92dac3c16a255b92ed (diff) |
Correct exception handling in TFLite Python interpreter.
We were incorrectly returning nullptr all the time. We need
to return None sometimes and check for it.
PiperOrigin-RevId: 205098110
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index c38b692dcd..f97919363b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -340,6 +340,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { namespace { +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); @@ -362,7 +364,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, return nullptr; } - return nullptr; + Py_RETURN_NONE; } } // namespace @@ -371,10 +373,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); // Make a buffer copy but we must tell Numpy It owns that data or else @@ -396,10 +400,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); |