aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/arena_planner.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner.cc')
-rw-r--r--tensorflow/contrib/lite/arena_planner.cc42
1 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 4257e754ad..02442575b3 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -17,14 +17,6 @@ limitations under the License.
namespace tflite {
-namespace {
-
-// Memory allocation tuning
-constexpr const int kDefaultArenaAlignment = 64;
-constexpr const int kDefaultTensorAlignment = 4;
-
-} // namespace
-
struct AllocationInfo {
// The node index requesting this allocation.
int node;
@@ -36,12 +28,16 @@ struct AllocationInfo {
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
std::unique_ptr<GraphInfo> graph_info,
- bool preserve_inputs)
+ bool preserve_inputs, bool preserve_intermediates,
+ int tensor_alignment)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
persistent_arena_(kDefaultArenaAlignment),
- preserve_inputs_(preserve_inputs) {}
+ preserve_inputs_(preserve_inputs),
+ preserve_intermediates_(preserve_intermediates),
+ tensor_alignment_(tensor_alignment) {}
+
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -164,13 +160,15 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Then update the ref-counts of the node's inputs, and if necessary queue
// them for deallocation.
- TfLiteIntArray* node_inputs = node.inputs;
- for (int j = 0; j < node_inputs->size; ++j) {
- int tensor_index = node_inputs->data[j];
- if (tensor_index != kOptionalTensor) {
- refcounts[tensor_index]--;
- if (refcounts[tensor_index] == 0) {
- TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
+ if (!preserve_intermediates_) {
+ TfLiteIntArray* node_inputs = node.inputs;
+ for (int j = 0; j < node_inputs->size; ++j) {
+ int tensor_index = node_inputs->data[j];
+ if (tensor_index != kOptionalTensor) {
+ refcounts[tensor_index]--;
+ if (refcounts[tensor_index] == 0) {
+ TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
+ }
}
}
}
@@ -261,14 +259,12 @@ TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) {
TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) {
TfLiteTensor& tensor = *graph_info_->tensor(tensor_index);
if (tensor.allocation_type == kTfLiteArenaRw) {
- TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment,
- tensor.bytes,
- &allocs_[tensor_index]));
+ TF_LITE_ENSURE_STATUS(arena_.Allocate(
+ context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index]));
}
if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
- TF_LITE_ENSURE_STATUS(
- persistent_arena_.Allocate(context_, kDefaultTensorAlignment,
- tensor.bytes, &allocs_[tensor_index]));
+ TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate(
+ context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index]));
}
return kTfLiteOk;
}