aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/model.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-07 11:29:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 11:36:10 -0800
commit1fe793d36a2907ab063bc508fab264cf9e2c46db (patch)
tree77786f36914da8775ce5a2e24d3aaa788d7cd6b6 /tensorflow/contrib/lite/model.h
parent1e54177c916d97c34faa1a349b9898186f8b6325 (diff)
Adds support for loading model directly from a Flatbuffer object.
PiperOrigin-RevId: 178270704
Diffstat (limited to 'tensorflow/contrib/lite/model.h')
-rw-r--r--tensorflow/contrib/lite/model.h29
1 files changed, 20 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 15659d33f3..e0c96f7f04 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -45,18 +45,25 @@ namespace tflite {
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
public:
- // Build a model based on a file. Return a nullptr in case of failure.
+ // Builds a model based on a file. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
- // Build a model based on a pre-loaded flatbuffer. The caller retains
+ // Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
- // is destroyed. Return a nullptr in case of failure.
+ // is destroyed. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
const char* buffer, size_t buffer_size,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Builds a model directly from a flatbuffer pointer. The caller retains
+ // ownership of the buffer and should keep it alive until the returned object
+ // is destroyed. Returns a nullptr in case of failure.
+ static std::unique_ptr<FlatBufferModel> BuildFromModel(
+ const tflite::Model* model_spec,
+ ErrorReporter* error_reporter = DefaultErrorReporter());
+
// Releases memory or unmaps mmaped meory.
~FlatBufferModel();
@@ -75,7 +82,7 @@ class FlatBufferModel {
bool CheckModelIdentifier() const;
private:
- // Load a model from `filename`. If `mmap_file` is true then use mmap,
+ // Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
@@ -85,8 +92,8 @@ class FlatBufferModel {
ErrorReporter* error_reporter = DefaultErrorReporter(),
bool use_nnapi = false);
- // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to
- // remain alive and unchanged until the end of this flatbuffermodel's
+ // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
+ // to remain alive and unchanged until the end of this flatbuffermodel's
// lifetime.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
@@ -94,6 +101,10 @@ class FlatBufferModel {
FlatBufferModel(const char* ptr, size_t num_bytes,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ // Loads a model from Model flatbuffer. The `model` has to remain alive and
+ // unchanged until the end of this flatbuffermodel's lifetime.
+ FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
+
// Flatbuffer traverser pointer. (Model* is a pointer that is within the
// allocated memory of the data allocated by allocation's internals.
const tflite::Model* model_ = nullptr;
@@ -106,9 +117,9 @@ class FlatBufferModel {
// model are mapped to executable function pointers (TfLiteRegistrations).
class OpResolver {
public:
- // Find the op registration for a builtin operator by enum code.
+ // Finds the op registration for a builtin operator by enum code.
virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
- // Find the op registration of a custom operator by op name.
+ // Finds the op registration of a custom operator by op name.
virtual TfLiteRegistration* FindOp(const char* op) const = 0;
virtual ~OpResolver() {}
};
@@ -131,7 +142,7 @@ class InterpreterBuilder {
public:
InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver);
- // Build an interpreter given only the raw flatbuffer Model object (instead
+ // Builds an interpreter given only the raw flatbuffer Model object (instead
// of a FlatBufferModel). Mostly used for testing.
// If `error_reporter` is null, then DefaultErrorReporter() is used.
InterpreterBuilder(const ::tflite::Model* model,