diff options
author | 2017-12-07 11:29:51 -0800 | |
---|---|---|
committer | 2017-12-07 11:36:10 -0800 | |
commit | 1fe793d36a2907ab063bc508fab264cf9e2c46db (patch) | |
tree | 77786f36914da8775ce5a2e24d3aaa788d7cd6b6 /tensorflow/contrib/lite/model.h | |
parent | 1e54177c916d97c34faa1a349b9898186f8b6325 (diff) |
Adds support for loading model directly from a Flatbuffer object.
PiperOrigin-RevId: 178270704
Diffstat (limited to 'tensorflow/contrib/lite/model.h')
-rw-r--r-- | tensorflow/contrib/lite/model.h | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 15659d33f3..e0c96f7f04 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -45,18 +45,25 @@ namespace tflite { // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { public: - // Build a model based on a file. Return a nullptr in case of failure. + // Builds a model based on a file. Returns a nullptr in case of failure. static std::unique_ptr<FlatBufferModel> BuildFromFile( const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); - // Build a model based on a pre-loaded flatbuffer. The caller retains + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object - // is destroyed. Return a nullptr in case of failure. + // is destroyed. Returns a nullptr in case of failure. static std::unique_ptr<FlatBufferModel> BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr<FlatBufferModel> BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Releases memory or unmaps mmaped meory. ~FlatBufferModel(); @@ -75,7 +82,7 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Load a model from `filename`. If `mmap_file` is true then use mmap, + // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -85,8 +92,8 @@ class FlatBufferModel { ErrorReporter* error_reporter = DefaultErrorReporter(), bool use_nnapi = false); - // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to - // remain alive and unchanged until the end of this flatbuffermodel's + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's // lifetime. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -94,6 +101,10 @@ class FlatBufferModel { FlatBufferModel(const char* ptr, size_t num_bytes, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + // Flatbuffer traverser pointer. (Model* is a pointer that is within the // allocated memory of the data allocated by allocation's internals. const tflite::Model* model_ = nullptr; @@ -106,9 +117,9 @@ class FlatBufferModel { // model are mapped to executable function pointers (TfLiteRegistrations). class OpResolver { public: - // Find the op registration for a builtin operator by enum code. + // Finds the op registration for a builtin operator by enum code. virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - // Find the op registration of a custom operator by op name. + // Finds the op registration of a custom operator by op name. virtual TfLiteRegistration* FindOp(const char* op) const = 0; virtual ~OpResolver() {} }; @@ -131,7 +142,7 @@ class InterpreterBuilder { public: InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver); - // Build an interpreter given only the raw flatbuffer Model object (instead + // Builds an interpreter given only the raw flatbuffer Model object (instead // of a FlatBufferModel). Mostly used for testing. // If `error_reporter` is null, then DefaultErrorReporter() is used. InterpreterBuilder(const ::tflite::Model* model, |