diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-07-24 15:26:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-24 15:30:26 -0700 |
commit | ba4ccc83adc397936ac01f80dd04ee8b2686c929 (patch) | |
tree | 344aeddc2429e44181dff6ad08ba53002990c757 /tensorflow/contrib/lite/python | |
parent | 74a75900faf88d7ce4e05f4bebd2b872abdf16a9 (diff) |
Improve TFLite Python error handling.
When `InterpreterBuilder` fails, now it fails silently and later
user sees "Interpreter was not initialized.". There's no way to
get the error message and troubleshoot. This fixes the issue and
displays the error message.
PiperOrigin-RevId: 205902280
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 48 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h | 18 |
2 files changed, 50 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index f97919363b..9ab05f3068 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -108,7 +108,9 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter( ImportNumpy(); std::unique_ptr<tflite::Interpreter> interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; + } return interpreter; } @@ -182,13 +184,37 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( std::unique_ptr<tflite::FlatBufferModel> model, - std::unique_ptr<PythonErrorReporter> error_reporter) + std::unique_ptr<PythonErrorReporter> error_reporter, + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, + std::unique_ptr<tflite::Interpreter> interpreter) : model_(std::move(model)), error_reporter_(std::move(error_reporter)), - resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} @@ -421,11 +447,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( @@ -439,11 +462,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromBuffer(buf, length, error_reporter.get()); - if (!model) { - *error_msg = error_reporter->message(); - return nullptr; - } - return new InterpreterWrapper(std::move(model), std::move(error_reporter)); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 556ec7117a..3e03751da4 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -69,14 +69,28 @@ class InterpreterWrapper { PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model, - std::unique_ptr<PythonErrorReporter> error_reporter); + // Helper function to construct an `InterpreterWrapper` object. + // It only returns InterpreterWrapper if it can construct an `Interpreter`. + // Otherwise it returns `nullptr`. + static InterpreterWrapper* CreateInterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::string* error_msg); + + InterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, + std::unique_ptr<tflite::Interpreter> interpreter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // The public functions which creates `InterpreterWrapper` should ensure all + // these member variables are initialized successfully. Otherwise it should + // report the error and return `nullptr`. const std::unique_ptr<tflite::FlatBufferModel> model_; const std::unique_ptr<PythonErrorReporter> error_reporter_; const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_; |