diff options
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h | 40 |
1 files changed, 31 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index e7343cb388..3e03751da4 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -15,12 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #define TENSORFLOW_CONTRIB_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +// Place `<locale>` before <Python.h> to avoid build failures in macOS. +#include <locale> #include <memory> #include <string> #include <vector> -// Place `<locale>` before <Python.h> to avoid build failures in macOS. -#include <locale> #include <Python.h> // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -36,41 +36,63 @@ class Interpreter; namespace interpreter_wrapper { +class PythonErrorReporter; + class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path, + std::string* error_msg); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data, + std::string* error_msg); ~InterpreterWrapper(); - bool AllocateTensors(); - bool Invoke(); + PyObject* AllocateTensors(); + PyObject* Invoke(); PyObject* InputIndices() const; PyObject* OutputIndices() const; - bool ResizeInputTensor(int i, PyObject* value); + PyObject* ResizeInputTensor(int i, PyObject* value); std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; PyObject* TensorQuantization(int i) const; - bool SetTensor(int i, PyObject* value); + PyObject* SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + PyObject* ResetVariableTensorsToZero(); + // Returns a reference to tensor index i as a numpy array. The base_object // should be the interpreter object providing the memory. PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model); + // Helper function to construct an `InterpreterWrapper` object. + // It only returns InterpreterWrapper if it can construct an `Interpreter`. + // Otherwise it returns `nullptr`. + static InterpreterWrapper* CreateInterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::string* error_msg); + + InterpreterWrapper( + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<PythonErrorReporter> error_reporter, + std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver, + std::unique_ptr<tflite::Interpreter> interpreter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // The public functions which creates `InterpreterWrapper` should ensure all + // these member variables are initialized successfully. Otherwise it should + // report the error and return `nullptr`. const std::unique_ptr<tflite::FlatBufferModel> model_; + const std::unique_ptr<PythonErrorReporter> error_reporter_; const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_; const std::unique_ptr<tflite::Interpreter> interpreter_; }; |