aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/model.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-15 15:29:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 15:33:32 -0700
commit6c62e650252ab32f83637a8de6720e73ffeca226 (patch)
treea4133a93cada7b18238d607bc4d5e551f9e685e6 /tensorflow/contrib/lite/model.h
parent239eb8b652f94b43d51f7c7ffdbbfc02ad094a9c (diff)
Pass error reporter to file copy allocation,
and avoid loading model from file twice PiperOrigin-RevId: 189256489
Diffstat (limited to 'tensorflow/contrib/lite/model.h')
-rw-r--r--tensorflow/contrib/lite/model.h37
1 files changed, 20 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8dc1c794dc..38eea0e26b 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -41,6 +41,17 @@ limitations under the License.
namespace tflite {
+// Abstract interface that verifies whether a given model is legit.
+// It facilitates the use-case to verify and build a model without loading it
+// twice.
+class TfLiteVerifier {
+ public:
+ // Returns true if the model is legit.
+ virtual bool Verify(const char* data, int length,
+ ErrorReporter* reporter) = 0;
+ virtual ~TfLiteVerifier() {}
+};
+
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
@@ -50,6 +61,12 @@ class FlatBufferModel {
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Verifies whether the content of the file is legit, then builds a model
+ // based on the file. Returns a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
+ const char* filename, TfLiteVerifier* verifier = nullptr,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
@@ -82,23 +99,9 @@ class FlatBufferModel {
bool CheckModelIdentifier() const;
private:
- // Loads a model from `filename`. If `mmap_file` is true then use mmap,
- // otherwise make a copy of the model in a buffer.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- explicit FlatBufferModel(
- const char* filename, bool mmap_file = true,
- ErrorReporter* error_reporter = DefaultErrorReporter(),
- bool use_nnapi = false);
-
- // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
- // to remain alive and unchanged until the end of this flatbuffermodel's
- // lifetime.
- //
- // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
- // used.
- FlatBufferModel(const char* ptr, size_t num_bytes,
+ // Loads a model from a given allocation. FlatBufferModel will take over the
+ // ownership of `allocation`, and delete it in desctructor.
+ FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Loads a model from Model flatbuffer. The `model` has to remain alive and