aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc312
1 files changed, 186 insertions, 126 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 5554d08fa0..9ab05f3068 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -14,13 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+#include <sstream>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/core/platform/logging.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -38,9 +38,58 @@ limitations under the License.
#define CPP_TO_PYSTRING PyString_FromStringAndSize
#endif
+#define TFLITE_PY_CHECK(x) \
+ if ((x) != kTfLiteOk) { \
+ return error_reporter_->exception(); \
+ }
+
+#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \
+ if (i >= interpreter_->tensors_size() || i < 0) { \
+ PyErr_Format(PyExc_ValueError, \
+ "Invalid tensor index %d exceeds max tensor index %lu", i, \
+ interpreter_->tensors_size()); \
+ return nullptr; \
+ }
+
+#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
+ if (!interpreter_) { \
+ PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
+ return nullptr; \
+ }
+
namespace tflite {
namespace interpreter_wrapper {
+class PythonErrorReporter : public tflite::ErrorReporter {
+ public:
+ PythonErrorReporter() {}
+
+ // Report an error message
+ int Report(const char* format, va_list args) override {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << buf;
+ return formatted;
+ }
+
+  // Sets a Python runtime exception with the last error.
+ PyObject* exception() {
+ std::string last_message = message();
+ PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+ return nullptr;
+ }
+
+ // Gets the last error message and clears the buffer.
+ std::string message() {
+ std::string value = buffer_.str();
+ buffer_.clear();
+ return value;
+ }
+
+ private:
+ std::stringstream buffer_;
+};
+
namespace {
// Calls PyArray's initialization to initialize all the API pointers. Note that
@@ -59,19 +108,8 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
- tflite::InterpreterBuilder(*model, resolver)(&interpreter);
- if (interpreter) {
- for (const int input_index : interpreter->inputs()) {
- const TfLiteTensor* tensor = interpreter->tensor(input_index);
- CHECK(tensor);
- const TfLiteIntArray* dims = tensor->dims;
- if (!dims) {
- continue;
- }
-
- std::vector<int> input_dims(dims->data, dims->data + dims->size);
- interpreter->ResizeInputTensor(input_index, input_dims);
- }
+ if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
+ return nullptr;
}
return interpreter;
}
@@ -95,10 +133,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteNoType:
- return -1;
+ return NPY_NOTYPE;
+      // Avoid the default case so compiler errors are created when new types are added.
}
- LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type;
- return -1;
+ return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
@@ -122,8 +160,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteString;
case NPY_COMPLEX64:
return kTfLiteComplex64;
+      // Avoid the default case so compiler errors are created when new types are added.
}
- LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
return kTfLiteNoType;
}
@@ -146,33 +184,54 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
+InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::string* error_msg) {
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+ auto interpreter = CreateInterpreter(model.get(), *resolver);
+ if (!interpreter) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ InterpreterWrapper* wrapper =
+ new InterpreterWrapper(std::move(model), std::move(error_reporter),
+ std::move(resolver), std::move(interpreter));
+ return wrapper;
+}
+
InterpreterWrapper::InterpreterWrapper(
- std::unique_ptr<tflite::FlatBufferModel> model)
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
+ std::unique_ptr<tflite::Interpreter> interpreter)
: model_(std::move(model)),
- resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
- interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
+ error_reporter_(std::move(error_reporter)),
+ resolver_(std::move(resolver)),
+ interpreter_(std::move(interpreter)) {}
InterpreterWrapper::~InterpreterWrapper() {}
-bool InterpreterWrapper::AllocateTensors() {
- if (!interpreter_) {
- LOG(ERROR) << "Cannot allocate tensors: invalid interpreter.";
- return false;
- }
-
- if (interpreter_->AllocateTensors() != kTfLiteOk) {
- LOG(ERROR) << "Unable to allocate tensors.";
- return false;
- }
-
- return true;
+PyObject* InterpreterWrapper::AllocateTensors() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->AllocateTensors());
+ Py_RETURN_NONE;
}
-bool InterpreterWrapper::Invoke() {
- return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false;
+PyObject* InterpreterWrapper::Invoke() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->Invoke());
+ Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::InputIndices() const {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
interpreter_->inputs().size());
@@ -186,35 +245,36 @@ PyObject* InterpreterWrapper::OutputIndices() const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
-bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
+PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert numpy value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
if (PyArray_NDIM(array) != 1) {
- LOG(ERROR) << "Expected 1-D defining input shape.";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
+ PyArray_NDIM(array));
+ return nullptr;
}
if (PyArray_TYPE(array) != NPY_INT32) {
- LOG(ERROR) << "Shape must be an int32 array";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
+ PyArray_TYPE(array));
+ return nullptr;
}
std::vector<int> dims(PyArray_SHAPE(array)[0]);
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
- return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk);
+ TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
+ Py_RETURN_NONE;
}
std::string InterpreterWrapper::TensorName(int i) const {
@@ -227,21 +287,21 @@ std::string InterpreterWrapper::TensorName(int i) const {
}
PyObject* InterpreterWrapper::TensorType(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- return nullptr;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
- int typenum = TfLiteTypeToPyArrayType(tensor->type);
- return PyArray_TypeObjectFromType(typenum);
+ int code = TfLiteTypeToPyArrayType(tensor->type);
+ if (code == -1) {
+ PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
+ return nullptr;
+ }
+ return PyArray_TypeObjectFromType(code);
}
PyObject* InterpreterWrapper::TensorSize(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
@@ -250,100 +310,87 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
return PyTupleFromQuantizationParam(tensor->params);
}
-bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
-
- if (i >= interpreter_->tensors_size()) {
- LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
- << interpreter_->tensors_size();
- return false;
- }
+PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
if (TfLiteTypeFromPyArray(array) != tensor->type) {
- LOG(ERROR) << "Cannot set tensor:"
- << " Got tensor of type " << TfLiteTypeFromPyArray(array)
- << " but expected type " << tensor->type << " for input " << i;
- return false;
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor:"
+ " Got tensor of type %d"
+ " but expected type %d for input %d ",
+ TfLiteTypeFromPyArray(array), tensor->type, i);
+ return nullptr;
}
if (PyArray_NDIM(array) != tensor->dims->size) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
}
size_t size = PyArray_NBYTES(array);
- DCHECK_EQ(size, tensor->bytes);
+ if (size != tensor->bytes) {
+ PyErr_Format(PyExc_ValueError,
+ "numpy array had %zu bytes but expected %zu bytes.", size,
+ tensor->bytes);
+ return nullptr;
+ }
memcpy(tensor->data.raw, PyArray_DATA(array), size);
- return true;
+ Py_RETURN_NONE;
}
namespace {
-PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index,
+// Checks to see if a tensor access can succeed (returns nullptr on error).
+// Otherwise returns Py_None.
+PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
- if (!interpreter) {
- LOG(ERROR) << "Invalid interpreter.";
- Py_INCREF(Py_None);
- return Py_None;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
- if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) {
- LOG(ERROR) << "Invalid tensor index: " << tensor_index
- << " exceeds max tensor index " << interpreter->inputs().size();
- Py_INCREF(Py_None);
- return Py_None;
- }
-
- *tensor = interpreter->tensor(tensor_index);
+ *tensor = interpreter_->tensor(tensor_index);
if ((*tensor)->bytes == 0) {
- LOG(ERROR) << "Invalid tensor size";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
+ return nullptr;
}
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
if (*type_num == -1) {
- LOG(ERROR) << "Unknown tensor type " << (*tensor)->type;
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
+ return nullptr;
}
if (!(*tensor)->data.raw) {
- LOG(ERROR) << "Tensor data is null.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
+ return nullptr;
}
- return nullptr;
+ Py_RETURN_NONE;
}
} // namespace
@@ -352,19 +399,20 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
+
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
// Make a buffer copy but we must tell Numpy It owns that data or else
// it will leak.
void* data = malloc(tensor->bytes);
if (!data) {
- LOG(ERROR) << "Malloc to copy tensor failed.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
+ return nullptr;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
@@ -378,10 +426,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
@@ -394,22 +443,33 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
- const char* model_path) {
+ const char* model_path, std::string* error_msg) {
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
- PyObject* data) {
+ PyObject* data, std::string* error_msg) {
char * buf = nullptr;
Py_ssize_t length;
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
if (PY_TO_CPPSTRING(data, &buf, &length) == -1) {
return nullptr;
}
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromBuffer(buf, length,
+ error_reporter.get());
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
+}
+
+PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ Py_RETURN_NONE;
}
} // namespace interpreter_wrapper