diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-13 16:21:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-13 16:25:51 -0800 |
commit | e6ff665dbe4888aa5fdff8f34c44405acca2ddd1 (patch) | |
tree | a290fe97b76bebec82ef3b3a76c906acf6b55f41 /tensorflow/contrib/lite/interpreter.h | |
parent | a4973345351a14a786987cd7f648a99c029fdc1d (diff) |
Clean up the allocation logic in the interpreter.
PiperOrigin-RevId: 181865795
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 40 |
1 files changed, 16 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 65c61e44be..38dd402e8a 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/contrib/lite/memory_planner.h" namespace tflite { @@ -49,13 +49,6 @@ constexpr TfLiteType typeToTfLiteType<unsigned char>() { return kTfLiteUInt8; } -struct ArenaAllocRefCount { - ArenaAllocRefCount() : alloc(), count(0) {} - - ArenaAlloc alloc; - int count; -}; - // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -276,9 +269,17 @@ class Interpreter { return op_reg.invoke(&context_, node); } - // Allocate tensors whose sizes are known in order of nodes. Discontinue when - // we encounter a node that has a dynamic output tensor. - TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + // Call OpPrepare() for as many ops as possible, allocating memory for their + // tensors. If an op containing dynamic tensors is found, preparation will be + // postponed until this function is called again. This allows the interpreter + // to wait until Invoke() to resolve the sizes of dynamic tensors. + TfLiteStatus PrepareOpsAndTensors(); + + // Call OpPrepare() for all ops starting at 'first_node'. Stop when a + // dynamic tensors is found or all ops have been prepared. Fill + // 'last_node_prepared' with the id of the op containing dynamic tensors, or + // the last in the graph. + TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared); // Tensors needed by the interpreter. Use `AddTensors` to add more blank // tensor entries. Note, `tensors_.data()` needs to be synchronized to the @@ -325,17 +326,6 @@ class Interpreter { std::vector<std::pair<TfLiteNode, TfLiteRegistration>> nodes_and_registration_; - // Raw memory buffer that is allocated for all temporary and graph outputs. - // that are declared kTfLiteArenaRw. - SimpleMemoryArena arena_; - - // Raw memory buffer that is allocated for persistent tensors that are - // declared as kTfLiteArenaRwPersistent. - SimpleMemoryArena persistent_arena_; - - // Stores allocation and reference counts of all tensors. - std::vector<ArenaAllocRefCount> allocs_and_refcounts_; - // Whether the model is consistent. That is to say if the inputs and outputs // of every node and the global inputs and outputs are valid indexes into // the tensor array. @@ -356,7 +346,7 @@ class Interpreter { // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; - // Next node to allocate output tensors. + // Index of the next node to prepare. // During Invoke(), Interpreter will allocate input tensors first, which are // known to be fixed size. Then it will allocate outputs from nodes as many // as possible. When there is a node that produces dynamic sized tensor. @@ -364,10 +354,12 @@ class Interpreter { // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. - int next_allocate_node_id_; + int next_node_to_prepare_; // Whether to delegate to NN API std::unique_ptr<NNAPIDelegate> nnapi_delegate_; + + std::unique_ptr<MemoryPlanner> memory_planner_; }; } // namespace tflite |