diff options
author | 2018-06-29 16:19:37 -0700 | |
---|---|---|
committer | 2018-06-29 16:25:43 -0700 | |
commit | 8648bd52264116760c54de16ffbce6c98d7397e8 (patch) | |
tree | b26123fca57aab79b86c45c99a1b7724ceceec44 | |
parent | d7642767d24464127aae8c118caad597dea9e017 (diff) |
Do not overwrite inputs.
PiperOrigin-RevId: 202724720
-rw-r--r-- | tensorflow/contrib/lite/arena_planner.cc | 13 | ||||
-rw-r--r-- | tensorflow/contrib/lite/arena_planner.h | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/arena_planner_test.cc | 29 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 30 |
5 files changed, 52 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 22be64d6ff..4257e754ad 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -35,12 +35,13 @@ struct AllocationInfo { }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, - std::unique_ptr<GraphInfo> graph_info) + std::unique_ptr<GraphInfo> graph_info, + bool preserve_inputs) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment) {} - + persistent_arena_(kDefaultArenaAlignment), + preserve_inputs_(preserve_inputs) {} ArenaPlanner::~ArenaPlanner() {} int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { @@ -112,9 +113,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } - // Queue all graph inputs for allocation. + // Queue all graph inputs for allocation. If preserve_inputs_ is true, make + // sure they never be overwritten. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { + if (preserve_inputs_) { + refcounts[tensor_index]++; + } TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index e9d0fbc5a9..1d84950e91 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -43,8 +43,11 @@ struct AllocationInfo; class ArenaPlanner : public MemoryPlanner { public: // Ownership of 'context' is not taken and it must remain util the - // ArenaPlanner is destroyed. - ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info); + // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the + // graph will not share memory with any other tensor, effectively preserving + // them until the end of inference. + ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info, + bool preserve_inputs); ~ArenaPlanner() override; ArenaPlanner(const ArenaPlanner&) = delete; ArenaPlanner& operator=(const ArenaPlanner&) = delete; @@ -100,6 +103,8 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. SimpleMemoryArena persistent_arena_; + + bool preserve_inputs_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index f0fd35216f..f5bd1932f9 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -151,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) { class ArenaPlannerTest : public ::testing::Test { protected: - void SetGraph(TestGraph* graph) { + void SetGraph(TestGraph* graph, bool preserve_inputs = false) { graph_ = graph; context_.ReportError = ReportError; planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)))); + &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)), + preserve_inputs)); CHECK(planner_->ResetAllocations() == kTfLiteOk); CHECK(planner_->PlanAllocations() == kTfLiteOk); } @@ -243,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) { EXPECT_EQ(GetOffset(3), 0); } +TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph, /*preserve_inputs=*/true); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Because we are keeping the inputs alive until the end (due to + // preserve_inputs=true), the output tensor will not be able to use that + // space. It will end up using the same are as tensor #2. + EXPECT_EQ(GetOffset(3), GetOffsetAfter(1)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { TestGraph graph({0, 1}, { diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 4b3ba5df10..dcb4ef593e 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -538,7 +538,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteStatus Interpreter::PrepareOpsAndTensors() { if (!memory_planner_) { memory_planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)))); + &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)), + /*preserve_inputs=*/true)); memory_planner_->PlanAllocations(); } diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 21cdf87d1e..6f13b43ebf 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -231,32 +231,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); - ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); - ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); - - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw); ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw); + // #7 is the one with the largest pointer. + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); } TEST(BasicInterpreter, BufferAccess) { |