diff options
author | 2018-03-15 15:29:39 -0700 | |
---|---|---|
committer | 2018-03-15 15:33:32 -0700 | |
commit | 6c62e650252ab32f83637a8de6720e73ffeca226 (patch) | |
tree | a4133a93cada7b18238d607bc4d5e551f9e685e6 /tensorflow/contrib/lite/model.h | |
parent | 239eb8b652f94b43d51f7c7ffdbbfc02ad094a9c (diff) |
Pass error reporter to file copy allocation,
and avoid loading model from file twice
PiperOrigin-RevId: 189256489
Diffstat (limited to 'tensorflow/contrib/lite/model.h')
-rw-r--r-- | tensorflow/contrib/lite/model.h | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 8dc1c794dc..38eea0e26b 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -41,6 +41,17 @@ limitations under the License. namespace tflite { +// Abstract interface that verifies whether a given model is legit. +// It facilitates the use-case to verify and build a model without loading it +// twice. +class TfLiteVerifier { + public: + // Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + // An RAII object that represents a read-only tflite model, copied from disk, // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { @@ -50,6 +61,12 @@ class FlatBufferModel { const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Verifies whether the content of the file is legit, then builds a model + // based on the file. Returns a nullptr in case of failure. + static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* verifier = nullptr, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object // is destroyed. Returns a nullptr in case of failure. @@ -82,23 +99,9 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Loads a model from `filename`. If `mmap_file` is true then use mmap, - // otherwise make a copy of the model in a buffer. - // - // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be - // used. - explicit FlatBufferModel( - const char* filename, bool mmap_file = true, - ErrorReporter* error_reporter = DefaultErrorReporter(), - bool use_nnapi = false); - - // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has - // to remain alive and unchanged until the end of this flatbuffermodel's - // lifetime. - // - // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be - // used. - FlatBufferModel(const char* ptr, size_t num_bytes, + // Loads a model from a given allocation. FlatBufferModel will take over the + // ownership of `allocation`, and delete it in desctructor. + FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter = DefaultErrorReporter()); // Loads a model from Model flatbuffer. The `model` has to remain alive and |