author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-29 16:19:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-29 16:25:43 -0700
commit    8648bd52264116760c54de16ffbce6c98d7397e8 (patch)
tree      b26123fca57aab79b86c45c99a1b7724ceceec44
parent    d7642767d24464127aae8c118caad597dea9e017 (diff)
Do not overwrite inputs.
PiperOrigin-RevId: 202724720
-rw-r--r--  tensorflow/contrib/lite/arena_planner.cc       13
-rw-r--r--  tensorflow/contrib/lite/arena_planner.h          9
-rw-r--r--  tensorflow/contrib/lite/arena_planner_test.cc   29
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc           3
-rw-r--r--  tensorflow/contrib/lite/interpreter_test.cc     30
5 files changed, 52 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 22be64d6ff..4257e754ad 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -35,12 +35,13 @@ struct AllocationInfo {
};
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
- std::unique_ptr<GraphInfo> graph_info)
+ std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
- persistent_arena_(kDefaultArenaAlignment) {}
-
+ persistent_arena_(kDefaultArenaAlignment),
+ preserve_inputs_(preserve_inputs) {}
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -112,9 +113,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}
- // Queue all graph inputs for allocation.
+ // Queue all graph inputs for allocation. If preserve_inputs_ is true, make
+ // sure they are never overwritten.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
+ if (preserve_inputs_) {
+ refcounts[tensor_index]++;
+ }
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}
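
The hunk above is the core of the change: besides the usual per-reader refcount, each graph input gets one extra reference when preserve_inputs_ is set, so its count can never reach zero during the graph walk and the planner never recycles its buffer. Below is a minimal, self-contained sketch of that counting idea only; the toy graph and all names are hypothetical and this is not the TFLite implementation.

#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Toy graph (hypothetical): tensor 0 is the graph input; op A reads 0 and
  // writes 1; op B reads 1 and writes 2.
  std::vector<std::pair<std::vector<int>, std::vector<int>>> ops = {
      {{0}, {1}},  // op A: {inputs}, {outputs}
      {{1}, {2}},  // op B
  };
  const bool preserve_inputs = true;

  std::map<int, int> refcounts;  // tensor index -> pending readers
  for (const auto& op : ops)
    for (int t : op.first) ++refcounts[t];
  // The extra reference on the input mirrors the refcounts[tensor_index]++
  // added in the patch: the count can no longer drop to zero while ops run.
  if (preserve_inputs) ++refcounts[0];

  for (const auto& op : ops)
    for (int t : op.first)
      if (--refcounts[t] == 0)
        std::cout << "tensor " << t << " may be deallocated/reused\n";
  std::cout << "input tensor 0 refcount after the walk: " << refcounts[0]
            << " (never reused)\n";
  return 0;
}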
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index e9d0fbc5a9..1d84950e91 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -43,8 +43,11 @@ struct AllocationInfo;
class ArenaPlanner : public MemoryPlanner {
public:
// Ownership of 'context' is not taken and it must remain until the
- // ArenaPlanner is destroyed.
- ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info);
+ // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the
+ // graph will not share memory with any other tensor, effectively preserving
+ // them until the end of inference.
+ ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs);
~ArenaPlanner() override;
ArenaPlanner(const ArenaPlanner&) = delete;
ArenaPlanner& operator=(const ArenaPlanner&) = delete;
@@ -100,6 +103,8 @@ class ArenaPlanner : public MemoryPlanner {
// Raw memory buffer that is allocated for persistent tensors that are
// declared as kTfLiteArenaRwPersistent.
SimpleMemoryArena persistent_arena_;
+
+ bool preserve_inputs_;
};
} // namespace tflite
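
Given the constructor documented above, direct construction now takes the extra flag. A hedged sketch of such a call site follows; the MakePlanner helper and the GraphInfo instance passed in are placeholders and are not part of this patch.

#include <memory>
#include <utility>

#include "tensorflow/contrib/lite/arena_planner.h"

// Hypothetical helper: builds an ArenaPlanner that keeps the graph inputs
// alive for the whole inference.
std::unique_ptr<tflite::MemoryPlanner> MakePlanner(
    TfLiteContext* context, std::unique_ptr<tflite::GraphInfo> graph_info) {
  return std::unique_ptr<tflite::MemoryPlanner>(new tflite::ArenaPlanner(
      context, std::move(graph_info), /*preserve_inputs=*/true));
}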
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index f0fd35216f..f5bd1932f9 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -151,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
class ArenaPlannerTest : public ::testing::Test {
protected:
- void SetGraph(TestGraph* graph) {
+ void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
graph_ = graph;
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph))));
+ &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
+ preserve_inputs));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
@@ -243,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) {
EXPECT_EQ(GetOffset(3), 0);
}
+TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) {
+ TestGraph graph({0, 1},
+ {
+ /* in, out, tmp */
+ {{0, 1}, {2}, {}}, // First op
+ {{2, 0}, {4, 5}, {}}, // Second op
+ {{4, 5}, {3}, {}} // Third op
+ },
+ {3});
+ SetGraph(&graph, /*preserve_inputs=*/true);
+ Execute(0, 10);
+
+ // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5
+ EXPECT_EQ(GetOffset(0), 0);
+ EXPECT_EQ(GetOffset(1), GetOffsetAfter(0));
+ EXPECT_EQ(GetOffset(2), GetOffsetAfter(1));
+ EXPECT_EQ(GetOffset(4), GetOffsetAfter(2));
+ EXPECT_EQ(GetOffset(5), GetOffsetAfter(4));
+ // Because we are keeping the inputs alive until the end (due to
+ // preserve_inputs=true), the output tensor will not be able to use that
+ // space. It will end up using the same area as tensor #2.
+ EXPECT_EQ(GetOffset(3), GetOffsetAfter(1));
+}
+
TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) {
TestGraph graph({0, 1},
{
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 4b3ba5df10..dcb4ef593e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -538,7 +538,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
TfLiteStatus Interpreter::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
+ &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
+ /*preserve_inputs=*/true));
memory_planner_->PlanAllocations();
}
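
With the Interpreter now passing /*preserve_inputs=*/true, the user-visible effect is that data written into an input tensor is still there after Invoke(). A rough sketch of that behavior, assuming an interpreter already built from some model whose first input is float; none of this is code from the patch.

#include "tensorflow/contrib/lite/interpreter.h"

void InputSurvivesInvoke(tflite::Interpreter* interpreter) {
  interpreter->AllocateTensors();
  float* in = interpreter->typed_input_tensor<float>(0);
  in[0] = 42.0f;          // write an input value
  interpreter->Invoke();  // run inference
  // Before this change the arena could hand the input's buffer to another
  // tensor during inference; with preserve_inputs=true, in[0] is still 42.0f.
}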
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 21cdf87d1e..6f13b43ebf 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -231,32 +231,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) {
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
- ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw);
- ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw);
- ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
-
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw);
ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw);
ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw);
ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw);
+ ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw);
+ // #7 is the one with the largest pointer.
+ ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
}
TEST(BasicInterpreter, BufferAccess) {