aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/arena_planner_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-29 16:19:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 16:25:43 -0700
commit8648bd52264116760c54de16ffbce6c98d7397e8 (patch)
treeb26123fca57aab79b86c45c99a1b7724ceceec44 /tensorflow/contrib/lite/arena_planner_test.cc
parentd7642767d24464127aae8c118caad597dea9e017 (diff)
Do not overwrite inputs.
PiperOrigin-RevId: 202724720
Diffstat (limited to 'tensorflow/contrib/lite/arena_planner_test.cc')
-rw-r--r--tensorflow/contrib/lite/arena_planner_test.cc29
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index f0fd35216f..f5bd1932f9 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -151,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
class ArenaPlannerTest : public ::testing::Test {
protected:
- void SetGraph(TestGraph* graph) {
+ void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
graph_ = graph;
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph))));
+ &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
+ preserve_inputs));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
@@ -243,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) {
EXPECT_EQ(GetOffset(3), 0);
}
+TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) {
+ TestGraph graph({0, 1},
+ {
+ /* in, out, tmp */
+ {{0, 1}, {2}, {}}, // First op
+ {{2, 0}, {4, 5}, {}}, // Second op
+ {{4, 5}, {3}, {}} // Third op
+ },
+ {3});
+ SetGraph(&graph, /*preserve_inputs=*/true);
+ Execute(0, 10);
+
+ // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5
+ EXPECT_EQ(GetOffset(0), 0);
+ EXPECT_EQ(GetOffset(1), GetOffsetAfter(0));
+ EXPECT_EQ(GetOffset(2), GetOffsetAfter(1));
+ EXPECT_EQ(GetOffset(4), GetOffsetAfter(2));
+ EXPECT_EQ(GetOffset(5), GetOffsetAfter(4));
+ // Because we are keeping the inputs alive until the end (due to
+ // preserve_inputs=true), the output tensor will not be able to use that
+ // space. It will end up using the same are as tensor #2.
+ EXPECT_EQ(GetOffset(3), GetOffsetAfter(1));
+}
+
TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) {
TestGraph graph({0, 1},
{