aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-07-24 15:26:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 15:30:26 -0700
commitba4ccc83adc397936ac01f80dd04ee8b2686c929 (patch)
tree344aeddc2429e44181dff6ad08ba53002990c757 /tensorflow/contrib/lite/python
parent74a75900faf88d7ce4e05f4bebd2b872abdf16a9 (diff)
Improve TFLite Python error handling.
When `InterpreterBuilder` fails, now it fails silently and later user sees "Interpreter was not initialized.". There's no way to get the error message and troubleshoot. This fixes the issue and displays the error message. PiperOrigin-RevId: 205902280
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc48
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h18
2 files changed, 50 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index f97919363b..9ab05f3068 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -108,7 +108,9 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
- tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
+ return nullptr;
+ }
return interpreter;
}
@@ -182,13 +184,37 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
+InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::string* error_msg) {
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+ auto interpreter = CreateInterpreter(model.get(), *resolver);
+ if (!interpreter) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+
+ InterpreterWrapper* wrapper =
+ new InterpreterWrapper(std::move(model), std::move(error_reporter),
+ std::move(resolver), std::move(interpreter));
+ return wrapper;
+}
+
InterpreterWrapper::InterpreterWrapper(
std::unique_ptr<tflite::FlatBufferModel> model,
- std::unique_ptr<PythonErrorReporter> error_reporter)
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
+ std::unique_ptr<tflite::Interpreter> interpreter)
: model_(std::move(model)),
error_reporter_(std::move(error_reporter)),
- resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
- interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
+ resolver_(std::move(resolver)),
+ interpreter_(std::move(interpreter)) {}
InterpreterWrapper::~InterpreterWrapper() {}
@@ -421,11 +447,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
- if (!model) {
- *error_msg = error_reporter->message();
- return nullptr;
- }
- return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
@@ -439,11 +462,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buf, length,
error_reporter.get());
- if (!model) {
- *error_msg = error_reporter->message();
- return nullptr;
- }
- return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+ return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
+ error_msg);
}
PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 556ec7117a..3e03751da4 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -69,14 +69,28 @@ class InterpreterWrapper {
PyObject* tensor(PyObject* base_object, int i);
private:
- InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
- std::unique_ptr<PythonErrorReporter> error_reporter);
+ // Helper function to construct an `InterpreterWrapper` object.
+ // It only returns InterpreterWrapper if it can construct an `Interpreter`.
+ // Otherwise it returns `nullptr`.
+ static InterpreterWrapper* CreateInterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::string* error_msg);
+
+ InterpreterWrapper(
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter,
+ std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
+ std::unique_ptr<tflite::Interpreter> interpreter);
// InterpreterWrapper is not copyable or assignable. We avoid the use of
// InterpreterWrapper() = delete here for SWIG compatibility.
InterpreterWrapper();
InterpreterWrapper(const InterpreterWrapper& rhs);
+ // The public functions which creates `InterpreterWrapper` should ensure all
+ // these member variables are initialized successfully. Otherwise it should
+ // report the error and return `nullptr`.
const std::unique_ptr<tflite::FlatBufferModel> model_;
const std::unique_ptr<PythonErrorReporter> error_reporter_;
const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;