diff options
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 312 |
1 files changed, 186 insertions, 126 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5554d08fa0..9ab05f3068 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include <sstream> #include <string> #include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/core/platform/logging.h" // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION @@ -38,9 +38,58 @@ limitations under the License. #define CPP_TO_PYSTRING PyString_FromStringAndSize #endif +#define TFLITE_PY_CHECK(x) \ + if ((x) != kTfLiteOk) { \ + return error_reporter_->exception(); \ + } + +#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \ + if (i >= interpreter_->tensors_size() || i < 0) { \ + PyErr_Format(PyExc_ValueError, \ + "Invalid tensor index %d exceeds max tensor index %lu", i, \ + interpreter_->tensors_size()); \ + return nullptr; \ + } + +#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ + if (!interpreter_) { \ + PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ + return nullptr; \ + } + namespace tflite { namespace interpreter_wrapper { +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; + } + + // Set's a Python runtime exception with the last error. + PyObject* exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; + } + + // Gets the last error message and clears the buffer. + std::string message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; + } + + private: + std::stringstream buffer_; +}; + namespace { // Calls PyArray's initialization to initialize all the API pointers. Note that @@ -59,19 +108,8 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( ImportNumpy(); std::unique_ptr<tflite::Interpreter> interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); - if (interpreter) { - for (const int input_index : interpreter->inputs()) { - const TfLiteTensor* tensor = interpreter->tensor(input_index); - CHECK(tensor); - const TfLiteIntArray* dims = tensor->dims; - if (!dims) { - continue; - } - - std::vector<int> input_dims(dims->data, dims->data + dims->size); - interpreter->ResizeInputTensor(input_index, input_dims); - } + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; } return interpreter; } @@ -95,10 +133,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { case kTfLiteComplex64: return NPY_COMPLEX64; case kTfLiteNoType: - return -1; + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type; - return -1; + return NPY_NOTYPE; } TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { @@ -122,8 +160,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteString; case NPY_COMPLEX64: return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; } @@ -146,33 +184,54 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( - std::unique_ptr<tflite::FlatBufferModel> model) + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, + std::unique_ptr<tflite::Interpreter> interpreter) : model_(std::move(model)), - resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + error_reporter_(std::move(error_reporter)), + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} -bool InterpreterWrapper::AllocateTensors() { - if (!interpreter_) { - LOG(ERROR) << "Cannot allocate tensors: invalid interpreter."; - return false; - } - - if (interpreter_->AllocateTensors() != kTfLiteOk) { - LOG(ERROR) << "Unable to allocate tensors."; - return false; - } - - return true; +PyObject* InterpreterWrapper::AllocateTensors() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->AllocateTensors()); + Py_RETURN_NONE; } -bool InterpreterWrapper::Invoke() { - return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false; +PyObject* InterpreterWrapper::Invoke() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->Invoke()); + Py_RETURN_NONE; } PyObject* InterpreterWrapper::InputIndices() const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), interpreter_->inputs().size()); @@ -186,35 +245,36 @@ PyObject* InterpreterWrapper::OutputIndices() const { return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); } -bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } +PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); std::unique_ptr<PyObject, PyDecrefDeleter> array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert numpy value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); if (PyArray_NDIM(array) != 1) { - LOG(ERROR) << "Expected 1-D defining input shape."; - return false; + PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.", + PyArray_NDIM(array)); + return nullptr; } if (PyArray_TYPE(array) != NPY_INT32) { - LOG(ERROR) << "Shape must be an int32 array"; - return false; + PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).", + PyArray_TYPE(array)); + return nullptr; } std::vector<int> dims(PyArray_SHAPE(array)[0]); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); - return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk); + TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + Py_RETURN_NONE; } std::string InterpreterWrapper::TensorName(int i) const { @@ -227,21 +287,21 @@ std::string InterpreterWrapper::TensorName(int i) const { } PyObject* InterpreterWrapper::TensorType(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - return nullptr; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); - int typenum = TfLiteTypeToPyArrayType(tensor->type); - return PyArray_TypeObjectFromType(typenum); + int code = TfLiteTypeToPyArrayType(tensor->type); + if (code == -1) { + PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); + return nullptr; + } + return PyArray_TypeObjectFromType(code); } PyObject* InterpreterWrapper::TensorSize(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); PyObject* np_array = PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); @@ -250,100 +310,87 @@ PyObject* InterpreterWrapper::TensorSize(int i) const { } PyObject* InterpreterWrapper::TensorQuantization(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); return PyTupleFromQuantizationParam(tensor->params); } -bool InterpreterWrapper::SetTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } - - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->tensors_size(); - return false; - } +PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); std::unique_ptr<PyObject, PyDecrefDeleter> array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); const TfLiteTensor* tensor = interpreter_->tensor(i); if (TfLiteTypeFromPyArray(array) != tensor->type) { - LOG(ERROR) << "Cannot set tensor:" - << " Got tensor of type " << TfLiteTypeFromPyArray(array) - << " but expected type " << tensor->type << " for input " << i; - return false; + PyErr_Format(PyExc_ValueError, + "Cannot set tensor:" + " Got tensor of type %d" + " but expected type %d for input %d ", + TfLiteTypeFromPyArray(array), tensor->type, i); + return nullptr; } if (PyArray_NDIM(array) != tensor->dims->size) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch"); + return nullptr; } } size_t size = PyArray_NBYTES(array); - DCHECK_EQ(size, tensor->bytes); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } memcpy(tensor->data.raw, PyArray_DATA(array), size); - return true; + Py_RETURN_NONE; } namespace { -PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. +PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { - if (!interpreter) { - LOG(ERROR) << "Invalid interpreter."; - Py_INCREF(Py_None); - return Py_None; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index); - if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { - LOG(ERROR) << "Invalid tensor index: " << tensor_index - << " exceeds max tensor index " << interpreter->inputs().size(); - Py_INCREF(Py_None); - return Py_None; - } - - *tensor = interpreter->tensor(tensor_index); + *tensor = interpreter_->tensor(tensor_index); if ((*tensor)->bytes == 0) { - LOG(ERROR) << "Invalid tensor size"; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Invalid tensor size."); + return nullptr; } *type_num = TfLiteTypeToPyArrayType((*tensor)->type); if (*type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); + return nullptr; } if (!(*tensor)->data.raw) { - LOG(ERROR) << "Tensor data is null."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Tensor data is null."); + return nullptr; } - return nullptr; + Py_RETURN_NONE; } } // namespace @@ -352,19 +399,20 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); // Make a buffer copy but we must tell Numpy It owns that data or else // it will leak. void* data = malloc(tensor->bytes); if (!data) { - LOG(ERROR) << "Malloc to copy tensor failed."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; } memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = @@ -378,10 +426,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); @@ -394,22 +443,33 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path) { + const char* model_path, std::string* error_msg) { + std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromFile(model_path); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data) { + PyObject* data, std::string* error_msg) { char * buf = nullptr; Py_ssize_t length; + std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { return nullptr; } std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); +} + +PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero()); + Py_RETURN_NONE; } } // namespace interpreter_wrapper |