aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/arena_planner.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-29 16:19:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 16:25:43 -0700
commit8648bd52264116760c54de16ffbce6c98d7397e8 (patch)
treeb26123fca57aab79b86c45c99a1b7724ceceec44 /tensorflow/contrib/lite/arena_planner.cc
parentd7642767d24464127aae8c118caad597dea9e017 (diff)
Do not overwrite inputs.
PiperOrigin-RevId: 202724720
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner.cc')
-rw-r--r--tensorflow/contrib/lite/arena_planner.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 22be64d6ff..4257e754ad 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -35,12 +35,13 @@ struct AllocationInfo {
};
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
- std::unique_ptr<GraphInfo> graph_info)
+ std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
- persistent_arena_(kDefaultArenaAlignment) {}
-
+ persistent_arena_(kDefaultArenaAlignment),
+ preserve_inputs_(preserve_inputs) {}
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -112,9 +113,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}
- // Queue all graph inputs for allocation.
+ // Queue all graph inputs for allocation. If preserve_inputs_ is true, make
+ // sure they never be overwritten.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
+ if (preserve_inputs_) {
+ refcounts[tensor_index]++;
+ }
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}