aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/arena_planner.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner.h')
-rw-r--r--tensorflow/contrib/lite/arena_planner.h24
1 files changed, 22 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index e9d0fbc5a9..55003cf4e9 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -25,6 +25,10 @@ limitations under the License.
namespace tflite {
+// Memory allocation tuning
+constexpr const int kDefaultArenaAlignment = 64;
+constexpr const int kDefaultTensorAlignment = 64;
+
struct AllocationInfo;
// A memory planner that makes all the allocations using arenas.
@@ -43,8 +47,12 @@ struct AllocationInfo;
class ArenaPlanner : public MemoryPlanner {
public:
// Ownership of 'context' is not taken and it must remain util the
- // ArenaPlanner is destroyed.
- ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info);
+ // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the
+ // graph will not share memory with any other tensor, effectively preserving
+ // them until the end of inference.
+ ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs, bool preserve_intermediates,
+ int tensor_alignment = kDefaultTensorAlignment);
~ArenaPlanner() override;
ArenaPlanner(const ArenaPlanner&) = delete;
ArenaPlanner& operator=(const ArenaPlanner&) = delete;
@@ -100,6 +108,18 @@ class ArenaPlanner : public MemoryPlanner {
// Raw memory buffer that is allocated for persistent tensors that are
// declared as kTfLiteArenaRwPersistent.
SimpleMemoryArena persistent_arena_;
+
+ // Ensure that the memory self-allocated for inputs is never reused by the
+ // allocator. This allows for example, multiple runs without getting
+ // unpredictable results.
+ bool preserve_inputs_;
+
+ // If true, then no overlapping of memory areas is done, meaning intermediates
+ // results can be queried after running (modulo running delegates).
+ bool preserve_intermediates_;
+
+ // Number of bytes that tensor buffers should be aligned to.
+ int tensor_alignment_;
};
} // namespace tflite