diff options
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 71 |
1 files changed, 48 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index c38b692dcd..9ab05f3068 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -108,7 +108,9 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( ImportNumpy(); std::unique_ptr<tflite::Interpreter> interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; + } return interpreter; } @@ -182,13 +184,37 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( std::unique_ptr<tflite::FlatBufferModel> model, - std::unique_ptr<PythonErrorReporter> error_reporter) + std::unique_ptr<PythonErrorReporter> error_reporter, + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, + std::unique_ptr<tflite::Interpreter> interpreter) : model_(std::move(model)), error_reporter_(std::move(error_reporter)), - resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} @@ -340,6 +366,8 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { namespace { +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); @@ -362,7 +390,7 @@ PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, return nullptr; } - return nullptr; + Py_RETURN_NONE; } } // namespace @@ -371,10 +399,12 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); // Make a buffer copy but we must tell Numpy It owns that data or else @@ -396,10 +426,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); std::vector<npy_intp> dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); @@ -416,11 +447,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( @@ -434,11 +462,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromBuffer(buf, length, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { |