author     Andrew Selle <aselle@google.com>                 2018-07-11 10:38:07 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-11 10:47:27 -0700
commit     a13948083e6a8100379cf02afecb8f37ce33f40a (patch)
tree       d4f59d98d0f60d650566c249a15f6b9a45a6f9e9 /tensorflow/contrib/lite/python
parent     77c25879baac86de114fd77eb753085eaa01b00e (diff)
Improve error handling and usability of Python interpreter
- Break dependency on tensorflow platform, logging, and absl.
- Propagate exceptions that capture the TensorFlow Lite errors in a buffer.

PiperOrigin-RevId: 204148724
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--  tensorflow/contrib/lite/python/BUILD                                         6
-rw-r--r--  tensorflow/contrib/lite/python/interpreter.py                               22
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_test.py                          23
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/BUILD                     1
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc  265
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h    22
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i    43
7 files changed, 245 insertions, 137 deletions
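
For callers, the practical effect of this change is that failures now surface as Python exceptions carrying the buffered TFLite error text, instead of booleans or a generic message. A minimal sketch of the new behavior (not part of this commit), assuming the contrib-era module path used throughout the diff; the byte string is just illustrative garbage:

    from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper

    try:
        # Construction from an invalid buffer now raises ValueError with the
        # message captured by the C++ error reporter (e.g. about the model
        # identifier), rather than a generic "Failed to create model" string.
        interpreter_wrapper.Interpreter(model_content=b'garbage')
    except ValueError as e:
        print(e)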
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 27909a9458..8c9608db04 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -19,6 +19,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
+ "//tensorflow/python:util",
],
)
@@ -30,9 +31,10 @@ py_test(
tags = ["no_oss"],
deps = [
":interpreter",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index fd90823425..e1981ceae2 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -56,9 +56,6 @@ class Interpreter(object):
self._interpreter = (
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
model_content))
- if not self._interpreter:
- raise ValueError(
- 'Failed to create model from {} bytes'.format(len(model_content)))
elif not model_path and not model_content:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
@@ -66,8 +63,7 @@ class Interpreter(object):
def allocate_tensors(self):
self._ensure_safe()
- if not self._interpreter.AllocateTensors():
- raise ValueError('Failed to allocate tensors')
+ return self._interpreter.AllocateTensors()
def _safe_to_run(self):
"""Returns true if there exist no numpy array buffers.
@@ -152,8 +148,7 @@ class Interpreter(object):
Raises:
ValueError: If the interpreter could not set the tensor.
"""
- if not self._interpreter.SetTensor(tensor_index, value):
- raise ValueError('Failed to set tensor')
+ self._interpreter.SetTensor(tensor_index, value)
def resize_tensor_input(self, input_index, tensor_size):
"""Resizes an input tensor.
@@ -167,8 +162,7 @@ class Interpreter(object):
ValueError: If the interpreter could not resize the input tensor.
"""
self._ensure_safe()
- if not self._interpreter.ResizeInputTensor(input_index, tensor_size):
- raise ValueError('Failed to resize input')
+ self._interpreter.ResizeInputTensor(input_index, tensor_size)
def get_output_details(self):
"""Gets model output details.
@@ -181,7 +175,9 @@ class Interpreter(object):
]
def get_tensor(self, tensor_index):
- """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`.
+ """Gets the value of the input tensor (get a copy).
+
+ If you wish to avoid the copy, use `tensor()`.
Args:
tensor_index: Tensor index of tensor to get. This value can be gotten from
@@ -247,5 +243,7 @@ class Interpreter(object):
ValueError: When the underlying interpreter fails raise ValueError.
"""
self._ensure_safe()
- if not self._interpreter.Invoke():
- raise ValueError('Failed to invoke TFLite model')
+ self._interpreter.Invoke()
+
+ def reset_all_variables_to_zero(self):
+ return self._interpreter.ResetVariableTensorsToZero()
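
With these changes the Python methods return None (or the wrapper's own result) and rely on the C++ layer to raise on failure. A short usage sketch, assuming a float32 model such as the tests' testdata/permute_float.tflite and the module's usual get_input_details()/get_output_details() accessors (the latter appears in the hunk above, the former is assumed):

    import numpy as np
    from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper

    interpreter = interpreter_wrapper.Interpreter(
        model_path='testdata/permute_float.tflite')
    interpreter.allocate_tensors()    # raises instead of returning False
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    data = np.array(
        np.random.random_sample(input_details[0]['shape']), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], data)
    interpreter.invoke()              # failures propagate as exceptions
    result = interpreter.get_tensor(output_details[0]['index'])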
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index 5f1fa26c3b..95fa4b8584 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import io
import numpy as np
+import six
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
from tensorflow.python.framework import test_util
@@ -91,6 +92,28 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertTrue((expected_output == output_data).all())
+class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
+
+ def testInvalidModelContent(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Model provided has model identifier \''):
+ interpreter_wrapper.Interpreter(model_content=six.b('garbage'))
+
+ def testInvalidModelFile(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Could not open \'totally_invalid_file_name\''):
+ interpreter_wrapper.Interpreter(
+ model_path='totally_invalid_file_name')
+
+ def testInvokeBeforeReady(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
+ with self.assertRaisesRegexp(RuntimeError,
+ 'Invoke called on model that is not ready'):
+ interpreter.invoke()
+
+
class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
def setUp(self):
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
index 634c2a1e1f..69ee95c320 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
@@ -13,7 +13,6 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/core:lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 5554d08fa0..c38b692dcd 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -14,13 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+#include <sstream>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/core/platform/logging.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -38,9 +38,58 @@ limitations under the License.
#define CPP_TO_PYSTRING PyString_FromStringAndSize
#endif
+#define TFLITE_PY_CHECK(x) \
+ if ((x) != kTfLiteOk) { \
+ return error_reporter_->exception(); \
+ }
+
+#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \
+ if (i >= interpreter_->tensors_size() || i < 0) { \
+ PyErr_Format(PyExc_ValueError, \
+ "Invalid tensor index %d exceeds max tensor index %lu", i, \
+ interpreter_->tensors_size()); \
+ return nullptr; \
+ }
+
+#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
+ if (!interpreter_) { \
+ PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
+ return nullptr; \
+ }
+
namespace tflite {
namespace interpreter_wrapper {
+class PythonErrorReporter : public tflite::ErrorReporter {
+ public:
+ PythonErrorReporter() {}
+
+ // Report an error message
+ int Report(const char* format, va_list args) override {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << buf;
+ return formatted;
+ }
+
+ // Sets a Python runtime exception with the last error.
+ PyObject* exception() {
+ std::string last_message = message();
+ PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+ return nullptr;
+ }
+
+ // Gets the last error message and clears the buffer.
+ std::string message() {
+ std::string value = buffer_.str();
+ buffer_.str("");
+ return value;
+ }
+
+ private:
+ std::stringstream buffer_;
+};
+
namespace {
// Calls PyArray's initialization to initialize all the API pointers. Note that
@@ -60,19 +109,6 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
- if (interpreter) {
- for (const int input_index : interpreter->inputs()) {
- const TfLiteTensor* tensor = interpreter->tensor(input_index);
- CHECK(tensor);
- const TfLiteIntArray* dims = tensor->dims;
- if (!dims) {
- continue;
- }
-
- std::vector<int> input_dims(dims->data, dims->data + dims->size);
- interpreter->ResizeInputTensor(input_index, input_dims);
- }
- }
return interpreter;
}
@@ -95,10 +131,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteNoType:
- return -1;
+ return NPY_NOTYPE;
+ // Avoid default so compiler errors are created when new types are made.
}
- LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type;
- return -1;
+ return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
@@ -122,8 +158,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteString;
case NPY_COMPLEX64:
return kTfLiteComplex64;
+ // Avoid default so compiler errors are created when new types are made.
}
- LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
return kTfLiteNoType;
}
@@ -147,32 +183,29 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
InterpreterWrapper::InterpreterWrapper(
- std::unique_ptr<tflite::FlatBufferModel> model)
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter)
: model_(std::move(model)),
+ error_reporter_(std::move(error_reporter)),
resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
InterpreterWrapper::~InterpreterWrapper() {}
-bool InterpreterWrapper::AllocateTensors() {
- if (!interpreter_) {
- LOG(ERROR) << "Cannot allocate tensors: invalid interpreter.";
- return false;
- }
-
- if (interpreter_->AllocateTensors() != kTfLiteOk) {
- LOG(ERROR) << "Unable to allocate tensors.";
- return false;
- }
-
- return true;
+PyObject* InterpreterWrapper::AllocateTensors() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->AllocateTensors());
+ Py_RETURN_NONE;
}
-bool InterpreterWrapper::Invoke() {
- return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false;
+PyObject* InterpreterWrapper::Invoke() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->Invoke());
+ Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::InputIndices() const {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
interpreter_->inputs().size());
@@ -186,35 +219,36 @@ PyObject* InterpreterWrapper::OutputIndices() const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
-bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
+PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert numpy value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
if (PyArray_NDIM(array) != 1) {
- LOG(ERROR) << "Expected 1-D defining input shape.";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
+ PyArray_NDIM(array));
+ return nullptr;
}
if (PyArray_TYPE(array) != NPY_INT32) {
- LOG(ERROR) << "Shape must be an int32 array";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
+ PyArray_TYPE(array));
+ return nullptr;
}
std::vector<int> dims(PyArray_SHAPE(array)[0]);
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
- return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk);
+ TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
+ Py_RETURN_NONE;
}
std::string InterpreterWrapper::TensorName(int i) const {
@@ -227,21 +261,21 @@ std::string InterpreterWrapper::TensorName(int i) const {
}
PyObject* InterpreterWrapper::TensorType(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- return nullptr;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
- int typenum = TfLiteTypeToPyArrayType(tensor->type);
- return PyArray_TypeObjectFromType(typenum);
+ int code = TfLiteTypeToPyArrayType(tensor->type);
+ if (code == -1) {
+ PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
+ return nullptr;
+ }
+ return PyArray_TypeObjectFromType(code);
}
PyObject* InterpreterWrapper::TensorSize(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
@@ -250,97 +284,82 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
return PyTupleFromQuantizationParam(tensor->params);
}
-bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
-
- if (i >= interpreter_->tensors_size()) {
- LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
- << interpreter_->tensors_size();
- return false;
- }
+PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
if (TfLiteTypeFromPyArray(array) != tensor->type) {
- LOG(ERROR) << "Cannot set tensor:"
- << " Got tensor of type " << TfLiteTypeFromPyArray(array)
- << " but expected type " << tensor->type << " for input " << i;
- return false;
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor:"
+ " Got tensor of type %d"
+ " but expected type %d for input %d ",
+ TfLiteTypeFromPyArray(array), tensor->type, i);
+ return nullptr;
}
if (PyArray_NDIM(array) != tensor->dims->size) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
}
size_t size = PyArray_NBYTES(array);
- DCHECK_EQ(size, tensor->bytes);
+ if (size != tensor->bytes) {
+ PyErr_Format(PyExc_ValueError,
+ "numpy array had %zu bytes but expected %zu bytes.", size,
+ tensor->bytes);
+ return nullptr;
+ }
memcpy(tensor->data.raw, PyArray_DATA(array), size);
- return true;
+ Py_RETURN_NONE;
}
namespace {
-PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index,
+PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
- if (!interpreter) {
- LOG(ERROR) << "Invalid interpreter.";
- Py_INCREF(Py_None);
- return Py_None;
- }
-
- if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) {
- LOG(ERROR) << "Invalid tensor index: " << tensor_index
- << " exceeds max tensor index " << interpreter->inputs().size();
- Py_INCREF(Py_None);
- return Py_None;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
- *tensor = interpreter->tensor(tensor_index);
+ *tensor = interpreter_->tensor(tensor_index);
if ((*tensor)->bytes == 0) {
- LOG(ERROR) << "Invalid tensor size";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
+ return nullptr;
}
*type_num = TfLiteTypeToPyArrayType((*tensor)->type);
if (*type_num == -1) {
- LOG(ERROR) << "Unknown tensor type " << (*tensor)->type;
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
+ return nullptr;
}
if (!(*tensor)->data.raw) {
- LOG(ERROR) << "Tensor data is null.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
+ return nullptr;
}
return nullptr;
@@ -362,9 +381,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// it will leak.
void* data = malloc(tensor->bytes);
if (!data) {
- LOG(ERROR) << "Malloc to copy tensor failed.";
- Py_INCREF(Py_None);
- return Py_None;
+ PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
+ return nullptr;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
@@ -394,22 +412,39 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
- const char* model_path) {
+ const char* model_path, std::string* error_msg) {
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
- PyObject* data) {
+ PyObject* data, std::string* error_msg) {
char * buf = nullptr;
Py_ssize_t length;
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
if (PY_TO_CPPSTRING(data, &buf, &length) == -1) {
return nullptr;
}
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromBuffer(buf, length,
+ error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+}
+
+PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ Py_RETURN_NONE;
}
} // namespace interpreter_wrapper
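
One visible consequence of the PyErr_Format-based checks in SetTensor is that dtype and shape mismatches now come back as ValueError with the offending details. Continuing the hedged usage sketch above (interpreter and input_details as set up there, and assuming the model expects float32 input):

    import numpy as np

    wrong = np.zeros(input_details[0]['shape'], dtype=np.float64)
    try:
        interpreter.set_tensor(input_details[0]['index'], wrong)
    except ValueError as e:
        # Formatted by the wrapper, roughly:
        # "Cannot set tensor: Got tensor of type ... but expected type ... for input ..."
        print(e)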
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 681448be20..febfd2dc56 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -36,34 +36,41 @@ class Interpreter;
namespace interpreter_wrapper {
+class PythonErrorReporter;
+
class InterpreterWrapper {
public:
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path);
+ static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path,
+ std::string* error_msg);
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data);
+ static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data,
+ std::string* error_msg);
~InterpreterWrapper();
- bool AllocateTensors();
- bool Invoke();
+ PyObject* AllocateTensors();
+ PyObject* Invoke();
PyObject* InputIndices() const;
PyObject* OutputIndices() const;
- bool ResizeInputTensor(int i, PyObject* value);
+ PyObject* ResizeInputTensor(int i, PyObject* value);
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
PyObject* TensorQuantization(int i) const;
- bool SetTensor(int i, PyObject* value);
+ PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
+ PyObject* ResetVariableTensorsToZero();
+
// Returns a reference to tensor index i as a numpy array. The base_object
// should be the interpreter object providing the memory.
PyObject* tensor(PyObject* base_object, int i);
private:
- InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model);
+ InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter);
// InterpreterWrapper is not copyable or assignable. We avoid the use of
// InterpreterWrapper() = delete here for SWIG compatibility.
@@ -71,6 +78,7 @@ class InterpreterWrapper {
InterpreterWrapper(const InterpreterWrapper& rhs);
const std::unique_ptr<tflite::FlatBufferModel> model_;
+ const std::unique_ptr<PythonErrorReporter> error_reporter_;
const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
const std::unique_ptr<tflite::Interpreter> interpreter_;
};
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
index 7f51f9f00d..afb2092eac 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
@@ -18,8 +18,51 @@ limitations under the License.
%{
#define SWIG_FILE_WITH_INIT
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
%}
%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+
+namespace tflite {
+namespace interpreter_wrapper {
+%extend InterpreterWrapper {
+
+ // Version of the constructor that handles producing Python exceptions
+ // that propagate strings.
+ static PyObject* CreateWrapperCPPFromFile(const char* model_path) {
+ std::string error;
+ if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromFile(
+ model_path, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+
+ // Version of the constructor that handles producing Python exceptions
+ // that propagate strings.
+ static PyObject* CreateWrapperCPPFromBuffer(
+ PyObject* data) {
+ std::string error;
+ if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromBuffer(
+ data, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+}
+
+} // namespace interpreter_wrapper
+} // namespace tflite
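
At the SWIG layer, the %extend shims above wrap the static constructors so the generated module-level functions that interpreter.py calls (InterpreterWrapper_CreateWrapperCPPFromBuffer, and its FromFile counterpart) raise ValueError directly with the reporter's message. A hedged sketch of exercising the wrapper module directly; the import path and flattened SWIG function name are assumptions based on the BUILD target and the calls in interpreter.py:

    from tensorflow.contrib.lite.python.interpreter_wrapper import (
        tensorflow_wrap_interpreter_wrapper as _interpreter_wrapper)

    try:
        _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
            'totally_invalid_file_name')
    except ValueError as e:
        print(e)    # e.g. "Could not open 'totally_invalid_file_name'."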