aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/delegate_test.cc')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc52
1 files changed, 50 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 88fb34044e..511a239363 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -25,8 +25,6 @@ namespace {
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
-// TODO(nupurgarg): Add a test with multiple interpreters for one delegate.
-
class DelegateTest : public testing::EagerModelTest {
public:
DelegateTest() {
@@ -139,6 +137,56 @@ TEST_F(DelegateTest, OnlyTFLite) {
ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
+TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
+ // Build a graph, configure the delegate and set inputs.
+ {
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfOp(testing::kMul, {6, 7}, {8});
+ ConfigureDelegate();
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+ }
+
+ // Create a new interpreter, inject into the test framework and build
+ // a different graph using the *same* delegate.
+ std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_));
+ interpreter_.swap(interpreter);
+ {
+ AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kAdd, {1, 2}, {3});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfLiteMulOp({4, 5}, {6});
+ AddTfOp(testing::kUnpack, {6}, {7, 8});
+ AddTfOp(testing::kAdd, {7, 8}, {9});
+ ConfigureDelegate();
+ SetShape(0, {2, 2, 2, 1});
+ SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
+ }
+
+ // Swap back in the first interpreter and validate inference.
+ interpreter_.swap(interpreter);
+ {
+ ASSERT_TRUE(Invoke());
+ EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
+ EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ }
+
+ // Swap in the second interpreter and validate inference.
+ interpreter_.swap(interpreter);
+ {
+ ASSERT_TRUE(Invoke());
+ EXPECT_THAT(GetShape(9), ElementsAre(1));
+ EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
+ }
+}
+
} // namespace
} // namespace eager
} // namespace tflite