aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-13 16:21:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-13 16:25:51 -0800
commite6ff665dbe4888aa5fdff8f34c44405acca2ddd1 (patch)
treea290fe97b76bebec82ef3b3a76c906acf6b55f41 /tensorflow/contrib/lite/interpreter.h
parenta4973345351a14a786987cd7f648a99c029fdc1d (diff)
Clean up the allocation logic in the interpreter.
PiperOrigin-RevId: 181865795
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.h')
-rw-r--r--tensorflow/contrib/lite/interpreter.h40
1 files changed, 16 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 65c61e44be..38dd402e8a 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
-#include "tensorflow/contrib/lite/simple_memory_arena.h"
+#include "tensorflow/contrib/lite/memory_planner.h"
namespace tflite {
@@ -49,13 +49,6 @@ constexpr TfLiteType typeToTfLiteType<unsigned char>() {
return kTfLiteUInt8;
}
-struct ArenaAllocRefCount {
- ArenaAllocRefCount() : alloc(), count(0) {}
-
- ArenaAlloc alloc;
- int count;
-};
-
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
@@ -276,9 +269,17 @@ class Interpreter {
return op_reg.invoke(&context_, node);
}
- // Allocate tensors whose sizes are known in order of nodes. Discontinue when
- // we encounter a node that has a dynamic output tensor.
- TfLiteStatus AllocateTensorsWhoseSizesAreKnown();
+ // Call OpPrepare() for as many ops as possible, allocating memory for their
+ // tensors. If an op containing dynamic tensors is found, preparation will be
+ // postponed until this function is called again. This allows the interpreter
+ // to wait until Invoke() to resolve the sizes of dynamic tensors.
+ TfLiteStatus PrepareOpsAndTensors();
+
+ // Call OpPrepare() for all ops starting at 'first_node'. Stop when a
+ // dynamic tensors is found or all ops have been prepared. Fill
+ // 'last_node_prepared' with the id of the op containing dynamic tensors, or
+ // the last in the graph.
+ TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared);
// Tensors needed by the interpreter. Use `AddTensors` to add more blank
// tensor entries. Note, `tensors_.data()` needs to be synchronized to the
@@ -325,17 +326,6 @@ class Interpreter {
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
nodes_and_registration_;
- // Raw memory buffer that is allocated for all temporary and graph outputs.
- // that are declared kTfLiteArenaRw.
- SimpleMemoryArena arena_;
-
- // Raw memory buffer that is allocated for persistent tensors that are
- // declared as kTfLiteArenaRwPersistent.
- SimpleMemoryArena persistent_arena_;
-
- // Stores allocation and reference counts of all tensors.
- std::vector<ArenaAllocRefCount> allocs_and_refcounts_;
-
// Whether the model is consistent. That is to say if the inputs and outputs
// of every node and the global inputs and outputs are valid indexes into
// the tensor array.
@@ -356,7 +346,7 @@ class Interpreter {
// The error reporter delegate that tflite will forward queries errors to.
ErrorReporter* error_reporter_;
- // Next node to allocate output tensors.
+ // Index of the next node to prepare.
// During Invoke(), Interpreter will allocate input tensors first, which are
// known to be fixed size. Then it will allocate outputs from nodes as many
// as possible. When there is a node that produces dynamic sized tensor.
@@ -364,10 +354,12 @@ class Interpreter {
// node id, and execute the node to generate the output tensor before continue
// to allocate successors. This process repeats until all nodes are executed.
// NOTE: this relies on the order of nodes that is in topological order.
- int next_allocate_node_id_;
+ int next_node_to_prepare_;
// Whether to delegate to NN API
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
+
+ std::unique_ptr<MemoryPlanner> memory_planner_;
};
} // namespace tflite