diff options
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner.cc')
-rw-r--r-- | tensorflow/contrib/lite/arena_planner.cc | 42 |
1 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4257e754ad..02442575b3 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -17,14 +17,6 @@ limitations under the License. namespace tflite { -namespace { - -// Memory allocation tuning -constexpr const int kDefaultArenaAlignment = 64; -constexpr const int kDefaultTensorAlignment = 4; - -} // namespace - struct AllocationInfo { // The node index requesting this allocation. int node; @@ -36,12 +28,16 @@ struct AllocationInfo { ArenaPlanner::ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info, - bool preserve_inputs) + bool preserve_inputs, bool preserve_intermediates, + int tensor_alignment) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), persistent_arena_(kDefaultArenaAlignment), - preserve_inputs_(preserve_inputs) {} + preserve_inputs_(preserve_inputs), + preserve_intermediates_(preserve_intermediates), + tensor_alignment_(tensor_alignment) {} + ArenaPlanner::~ArenaPlanner() {} int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { @@ -164,13 +160,15 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Then update the ref-counts of the node's inputs, and if necessary queue // them for deallocation. - TfLiteIntArray* node_inputs = node.inputs; - for (int j = 0; j < node_inputs->size; ++j) { - int tensor_index = node_inputs->data[j]; - if (tensor_index != kOptionalTensor) { - refcounts[tensor_index]--; - if (refcounts[tensor_index] == 0) { - TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + if (!preserve_intermediates_) { + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + } } } } @@ -261,14 +259,12 @@ TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) { TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, - &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_STATUS( - persistent_arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } return kTfLiteOk; } |