diff options
author | Andrew Selle <aselle@google.com> | 2018-07-11 10:38:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-11 10:47:27 -0700 |
commit | a13948083e6a8100379cf02afecb8f37ce33f40a (patch) | |
tree | d4f59d98d0f60d650566c249a15f6b9a45a6f9e9 /tensorflow | |
parent | 77c25879baac86de114fd77eb753085eaa01b00e (diff) |
Improve error handling and usability of Python interpreter
- Break dependency on tensorflow platform and logging and absl
- Propagate exceptions that capture the TensorFlow lite errors
in a buffer.
PiperOrigin-RevId: 204148724
Diffstat (limited to 'tensorflow')
7 files changed, 245 insertions, 137 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 27909a9458..8c9608db04 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -19,6 +19,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", + "//tensorflow/python:util", ], ) @@ -30,9 +31,10 @@ py_test( tags = ["no_oss"], deps = [ ":interpreter", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index fd90823425..e1981ceae2 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -56,9 +56,6 @@ class Interpreter(object): self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( model_content)) - if not self._interpreter: - raise ValueError( - 'Failed to create model from {} bytes'.format(len(model_content))) elif not model_path and not model_path: raise ValueError('`model_path` or `model_content` must be specified.') else: @@ -66,8 +63,7 @@ class Interpreter(object): def allocate_tensors(self): self._ensure_safe() - if not self._interpreter.AllocateTensors(): - raise ValueError('Failed to allocate tensors') + return self._interpreter.AllocateTensors() def _safe_to_run(self): """Returns true if there exist no numpy array buffers. @@ -152,8 +148,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not set the tensor. """ - if not self._interpreter.SetTensor(tensor_index, value): - raise ValueError('Failed to set tensor') + self._interpreter.SetTensor(tensor_index, value) def resize_tensor_input(self, input_index, tensor_size): """Resizes an input tensor. @@ -167,8 +162,7 @@ class Interpreter(object): ValueError: If the interpreter could not resize the input tensor. """ self._ensure_safe() - if not self._interpreter.ResizeInputTensor(input_index, tensor_size): - raise ValueError('Failed to resize input') + self._interpreter.ResizeInputTensor(input_index, tensor_size) def get_output_details(self): """Gets model output details. @@ -181,7 +175,9 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`. + """Gets the value of the input tensor (get a copy). + + If you wish to avoid the copy, use `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -247,5 +243,7 @@ class Interpreter(object): ValueError: When the underlying interpreter fails raise ValueError. """ self._ensure_safe() - if not self._interpreter.Invoke(): - raise ValueError('Failed to invoke TFLite model') + self._interpreter.Invoke() + + def reset_all_variables_to_zero(self): + return self._interpreter.ResetVariableTensorsToZero() diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index 5f1fa26c3b..95fa4b8584 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import io import numpy as np +import six from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper from tensorflow.python.framework import test_util @@ -91,6 +92,28 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): + + def testInvalidModelContent(self): + with self.assertRaisesRegexp(ValueError, + 'Model provided has model identifier \''): + interpreter_wrapper.Interpreter(model_content=six.b('garbage')) + + def testInvalidModelFile(self): + with self.assertRaisesRegexp( + ValueError, 'Could not open \'totally_invalid_file_name\''): + interpreter_wrapper.Interpreter( + model_path='totally_invalid_file_name') + + def testInvokeBeforeReady(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + with self.assertRaisesRegexp(RuntimeError, + 'Invoke called on model that is not ready'): + interpreter.invoke() + + class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 634c2a1e1f..69ee95c320 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -13,7 +13,6 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/core:lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5554d08fa0..c38b692dcd 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include <sstream> #include <string> #include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/core/platform/logging.h" // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION @@ -38,9 +38,58 @@ limitations under the License. #define CPP_TO_PYSTRING PyString_FromStringAndSize #endif +#define TFLITE_PY_CHECK(x) \ + if ((x) != kTfLiteOk) { \ + return error_reporter_->exception(); \ + } + +#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \ + if (i >= interpreter_->tensors_size() || i < 0) { \ + PyErr_Format(PyExc_ValueError, \ + "Invalid tensor index %d exceeds max tensor index %lu", i, \ + interpreter_->tensors_size()); \ + return nullptr; \ + } + +#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ + if (!interpreter_) { \ + PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ + return nullptr; \ + } + namespace tflite { namespace interpreter_wrapper { +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; + } + + // Set's a Python runtime exception with the last error. + PyObject* exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; + } + + // Gets the last error message and clears the buffer. + std::string message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; + } + + private: + std::stringstream buffer_; +}; + namespace { // Calls PyArray's initialization to initialize all the API pointers. Note that @@ -60,19 +109,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( std::unique_ptr<tflite::Interpreter> interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); - if (interpreter) { - for (const int input_index : interpreter->inputs()) { - const TfLiteTensor* tensor = interpreter->tensor(input_index); - CHECK(tensor); - const TfLiteIntArray* dims = tensor->dims; - if (!dims) { - continue; - } - - std::vector<int> input_dims(dims->data, dims->data + dims->size); - interpreter->ResizeInputTensor(input_index, input_dims); - } - } return interpreter; } @@ -95,10 +131,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { case kTfLiteComplex64: return NPY_COMPLEX64; case kTfLiteNoType: - return -1; + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type; - return -1; + return NPY_NOTYPE; } TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { @@ -122,8 +158,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteString; case NPY_COMPLEX64: return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; } @@ -147,32 +183,29 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace InterpreterWrapper::InterpreterWrapper( - std::unique_ptr<tflite::FlatBufferModel> model) + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter) : model_(std::move(model)), + error_reporter_(std::move(error_reporter)), resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()), interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} InterpreterWrapper::~InterpreterWrapper() {} -bool InterpreterWrapper::AllocateTensors() { - if (!interpreter_) { - LOG(ERROR) << "Cannot allocate tensors: invalid interpreter."; - return false; - } - - if (interpreter_->AllocateTensors() != kTfLiteOk) { - LOG(ERROR) << "Unable to allocate tensors."; - return false; - } - - return true; +PyObject* InterpreterWrapper::AllocateTensors() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->AllocateTensors()); + Py_RETURN_NONE; } -bool InterpreterWrapper::Invoke() { - return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false; +PyObject* InterpreterWrapper::Invoke() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->Invoke()); + Py_RETURN_NONE; } PyObject* InterpreterWrapper::InputIndices() const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), interpreter_->inputs().size()); @@ -186,35 +219,36 @@ PyObject* InterpreterWrapper::OutputIndices() const { return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); } -bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } +PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); std::unique_ptr<PyObject, PyDecrefDeleter> array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert numpy value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); if (PyArray_NDIM(array) != 1) { - LOG(ERROR) << "Expected 1-D defining input shape."; - return false; + PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.", + PyArray_NDIM(array)); + return nullptr; } if (PyArray_TYPE(array) != NPY_INT32) { - LOG(ERROR) << "Shape must be an int32 array"; - return false; + PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).", + PyArray_TYPE(array)); + return nullptr; } std::vector<int> dims(PyArray_SHAPE(array)[0]); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); - return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk); + TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + Py_RETURN_NONE; } std::string InterpreterWrapper::TensorName(int i) const { @@ -227,21 +261,21 @@ std::string InterpreterWrapper::TensorName(int i) const { } PyObject* InterpreterWrapper::TensorType(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - return nullptr; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); - int typenum = TfLiteTypeToPyArrayType(tensor->type); - return PyArray_TypeObjectFromType(typenum); + int code = TfLiteTypeToPyArrayType(tensor->type); + if (code == -1) { + PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); + return nullptr; + } + return PyArray_TypeObjectFromType(code); } PyObject* InterpreterWrapper::TensorSize(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); PyObject* np_array = PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); @@ -250,97 +284,82 @@ PyObject* InterpreterWrapper::TensorSize(int i) const { } PyObject* InterpreterWrapper::TensorQuantization(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); return PyTupleFromQuantizationParam(tensor->params); } -bool InterpreterWrapper::SetTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } - - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->tensors_size(); - return false; - } +PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); std::unique_ptr<PyObject, PyDecrefDeleter> array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get()); const TfLiteTensor* tensor = interpreter_->tensor(i); if (TfLiteTypeFromPyArray(array) != tensor->type) { - LOG(ERROR) << "Cannot set tensor:" - << " Got tensor of type " << TfLiteTypeFromPyArray(array) - << " but expected type " << tensor->type << " for input " << i; - return false; + PyErr_Format(PyExc_ValueError, + "Cannot set tensor:" + " Got tensor of type %d" + " but expected type %d for input %d ", + TfLiteTypeFromPyArray(array), tensor->type, i); + return nullptr; } if (PyArray_NDIM(array) != tensor->dims->size) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch"); + return nullptr; } } size_t size = PyArray_NBYTES(array); - DCHECK_EQ(size, tensor->bytes); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } memcpy(tensor->data.raw, PyArray_DATA(array), size); - return true; + Py_RETURN_NONE; } namespace { -PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, +PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { - if (!interpreter) { - LOG(ERROR) << "Invalid interpreter."; - Py_INCREF(Py_None); - return Py_None; - } - - if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { - LOG(ERROR) << "Invalid tensor index: " << tensor_index - << " exceeds max tensor index " << interpreter->inputs().size(); - Py_INCREF(Py_None); - return Py_None; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index); - *tensor = interpreter->tensor(tensor_index); + *tensor = interpreter_->tensor(tensor_index); if ((*tensor)->bytes == 0) { - LOG(ERROR) << "Invalid tensor size"; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Invalid tensor size."); + return nullptr; } *type_num = TfLiteTypeToPyArrayType((*tensor)->type); if (*type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); + return nullptr; } if (!(*tensor)->data.raw) { - LOG(ERROR) << "Tensor data is null."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Tensor data is null."); + return nullptr; } return nullptr; @@ -362,9 +381,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // it will leak. void* data = malloc(tensor->bytes); if (!data) { - LOG(ERROR) << "Malloc to copy tensor failed."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; } memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = @@ -394,22 +412,39 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path) { + const char* model_path, std::string* error_msg) { + std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromFile(model_path); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + return new InterpreterWrapper(std::move(model), std::move(error_reporter)); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data) { + PyObject* data, std::string* error_msg) { char * buf = nullptr; Py_ssize_t length; + std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { return nullptr; } std::unique_ptr<tflite::FlatBufferModel> model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + return new InterpreterWrapper(std::move(model), std::move(error_reporter)); +} + +PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero()); + Py_RETURN_NONE; } } // namespace interpreter_wrapper diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 681448be20..febfd2dc56 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -36,34 +36,41 @@ class Interpreter; namespace interpreter_wrapper { +class PythonErrorReporter; + class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path, + std::string* error_msg); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data, + std::string* error_msg); ~InterpreterWrapper(); - bool AllocateTensors(); - bool Invoke(); + PyObject* AllocateTensors(); + PyObject* Invoke(); PyObject* InputIndices() const; PyObject* OutputIndices() const; - bool ResizeInputTensor(int i, PyObject* value); + PyObject* ResizeInputTensor(int i, PyObject* value); std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; PyObject* TensorQuantization(int i) const; - bool SetTensor(int i, PyObject* value); + PyObject* SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + PyObject* ResetVariableTensorsToZero(); + // Returns a reference to tensor index i as a numpy array. The base_object // should be the interpreter object providing the memory. PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model); + InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. @@ -71,6 +78,7 @@ class InterpreterWrapper { InterpreterWrapper(const InterpreterWrapper& rhs); const std::unique_ptr<tflite::FlatBufferModel> model_; + const std::unique_ptr<PythonErrorReporter> error_reporter_; const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_; const std::unique_ptr<tflite::Interpreter> interpreter_; }; diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i index 7f51f9f00d..afb2092eac 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -18,8 +18,51 @@ limitations under the License. %{ #define SWIG_FILE_WITH_INIT +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" %} %include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" + +namespace tflite { +namespace interpreter_wrapper { +%extend InterpreterWrapper { + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromFile(const char* model_path) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromFile( + model_path, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromBuffer( + PyObject* data) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromBuffer( + data, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } +} + +} // namespace interpreter_wrapper +} // namespace tflite |