aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc71
1 files changed, 48 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index c38b692dcd..9ab05f3068 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -108,7 +108,9 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
- tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
+ return nullptr;
+ }
return interpreter;
}
@@ -182,13 +184,37 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
+InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::string* error_msg) {
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+ auto interpreter = CreateInterpreter(model.get(), *resolver);
+ if (!interpreter) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ InterpreterWrapper* wrapper =
+ new InterpreterWrapper(std::move(model), std::move(error_reporter),
+ std::move(resolver), std::move(interpreter));
+ return wrapper;
+}
+
InterpreterWrapper::InterpreterWrapper(
std::unique_ptr<tflite::FlatBufferModel> model,
- std::unique_ptr<PythonErrorReporter> error_reporter)
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
+ std::unique_ptr<tflite::Interpreter> interpreter)
: model_(std::move(model)),
error_reporter_(std::move(error_reporter)),
- resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
- interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
+ resolver_(std::move(resolver)),
+ interpreter_(std::move(interpreter)) {}
InterpreterWrapper::~InterpreterWrapper() {}
@@ -340,6 +366,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
namespace {
+// Checks to see if a tensor access can succeed (returns nullptr on error).
+// Otherwise returns Py_None.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
TfLiteTensor** tensor, int* type_num) {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
@@ -362,7 +390,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
return nullptr;
}
- return nullptr;
+ Py_RETURN_NONE;
}
} // namespace
@@ -371,10 +399,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
+
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
// Make a buffer copy but we must tell Numpy It owns that data or else
@@ -396,10 +426,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
// Sanity check accessor
TfLiteTensor* tensor = nullptr;
int type_num = 0;
- if (PyObject* pynone_or_nullptr =
- CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
- return pynone_or_nullptr;
- }
+
+ PyObject* check_result =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
+ if (check_result == nullptr) return check_result;
+ Py_XDECREF(check_result);
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
@@ -416,11 +447,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
- if (!model) {
- *error_msg = error_reporter->message();
- return nullptr;
- }
- return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
@@ -434,11 +462,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buf, length,
error_reporter.get());
- if (!model) {
- *error_msg = error_reporter->message();
- return nullptr;
- }
- return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
}
PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {