diff options
author | 2018-06-29 16:19:37 -0700 | |
---|---|---|
committer | 2018-06-29 16:25:43 -0700 | |
commit | 8648bd52264116760c54de16ffbce6c98d7397e8 (patch) | |
tree | b26123fca57aab79b86c45c99a1b7724ceceec44 /tensorflow/contrib/lite/arena_planner.cc | |
parent | d7642767d24464127aae8c118caad597dea9e017 (diff) |
Do not overwrite inputs.
PiperOrigin-RevId: 202724720
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner.cc')
-rw-r--r-- | tensorflow/contrib/lite/arena_planner.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 22be64d6ff..4257e754ad 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -35,12 +35,13 @@ struct AllocationInfo { }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, - std::unique_ptr<GraphInfo> graph_info) + std::unique_ptr<GraphInfo> graph_info, + bool preserve_inputs) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment) {} - + persistent_arena_(kDefaultArenaAlignment), + preserve_inputs_(preserve_inputs) {} ArenaPlanner::~ArenaPlanner() {} int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { @@ -112,9 +113,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } - // Queue all graph inputs for allocation. + // Queue all graph inputs for allocation. If preserve_inputs_ is true, make + // sure they never be overwritten. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { + if (preserve_inputs_) { + refcounts[tensor_index]++; + } TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } |