diff options
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 164 |
1 files changed, 131 insertions, 33 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 2e6727b323..11578fcb69 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -763,24 +763,38 @@ TfLiteRegistration AddOpRegistration() { } class TestDelegate : public ::testing::Test { - public: - TestDelegate() { - interpreter_.AddTensors(5); - interpreter_.SetInputs({0, 1}); - interpreter_.SetOutputs({3, 4}); + protected: + void SetUp() override { + interpreter_ = absl::make_unique<Interpreter>(); + interpreter_->AddTensors(5); + interpreter_->SetInputs({0, 1}); + interpreter_->SetOutputs({3, 4}); TfLiteQuantizationParams quant; - interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, - quant); - interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, - quant); + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); TfLiteRegistration reg = AddOpRegistration(); - interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); - interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); - interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + void TearDown() override { + // Interpreter relies on delegate_ to free the resources properly. Thus + // the life cycle of delegate must be longer than interpreter. + interpreter_.reset(); + delegate_.reset(); + } + + TfLiteDelegateBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle; + + TfLiteDelegateBufferHandle AllocateBufferHandle() { + return ++last_allocated_handle_; } protected: @@ -791,8 +805,8 @@ class TestDelegate : public ::testing::Test { // value-copyable and compatible with TfLite. explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) { delegate_.Prepare = [](TfLiteContext* context, - void* data) -> TfLiteStatus { - auto* simple = reinterpret_cast<SimpleDelegate*>(data); + TfLiteDelegate* delegate) -> TfLiteStatus { + auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_); TfLiteIntArray* nodes_to_separate = TfLiteIntArrayCreate(simple->nodes_.size()); // Mark nodes that we want in TfLiteIntArray* structure. @@ -823,10 +837,28 @@ class TestDelegate : public ::testing::Test { } context->ReplaceSubgraphsWithDelegateKernels( - context, FakeFusedRegistration(), nodes_to_separate); + context, FakeFusedRegistration(), nodes_to_separate, delegate); TfLiteIntArrayFree(nodes_to_separate); return kTfLiteOk; }; + delegate_.CopyToBufferHandle = + [](TfLiteDelegate* delegate, + TfLiteDelegateBufferHandle delegate_buffer_handle, void* data, + int size) -> TfLiteStatus { + // TODO(ycling): Implement tests to test buffer copying logic. + return kTfLiteOk; + }; + delegate_.CopyFromBufferHandle = + [](TfLiteDelegate* delegate, + TfLiteDelegateBufferHandle delegate_buffer_handle, void* data, + int size) -> TfLiteStatus { + // TODO(ycling): Implement tests to test buffer copying logic. + return kTfLiteOk; + }; + delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate, + TfLiteDelegateBufferHandle* handle) { + *handle = kTfLiteNullBufferHandle; + }; // Store type-punned data SimpleDelegate structure. delegate_.data_ = reinterpret_cast<void*>(this); } @@ -843,36 +875,102 @@ class TestDelegate : public ::testing::Test { std::vector<int> nodes_; TfLiteDelegate delegate_; }; - Interpreter interpreter_; + std::unique_ptr<Interpreter> interpreter_; + std::unique_ptr<SimpleDelegate> delegate_; }; TEST_F(TestDelegate, BasicDelegate) { - interpreter_.Invoke(); - SimpleDelegate simple({0, 1, 2}); - interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - ASSERT_EQ(interpreter_.execution_plan().size(), 1); - int node = interpreter_.execution_plan()[0]; - const auto* node_and_reg = interpreter_.node_and_registration(node); + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + int node = interpreter_->execution_plan()[0]; + const auto* node_and_reg = interpreter_->node_and_registration(node); ASSERT_EQ(node_and_reg->second.custom_name, SimpleDelegate::FakeFusedRegistration().custom_name); } TEST_F(TestDelegate, ComplexDeligate) { - interpreter_.Invoke(); - SimpleDelegate simple({1, 2}); - interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - ASSERT_EQ(interpreter_.execution_plan().size(), 2); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); // 0th should be a non-delegated original op - ASSERT_EQ(interpreter_.execution_plan()[0], 0); + ASSERT_EQ(interpreter_->execution_plan()[0], 0); // 1st should be a new macro op (3) which didn't exist) - ASSERT_EQ(interpreter_.execution_plan()[1], 3); - const auto* node_and_reg = interpreter_.node_and_registration(3); + ASSERT_EQ(interpreter_->execution_plan()[1], 3); + const auto* node_and_reg = interpreter_->node_and_registration(3); ASSERT_EQ(node_and_reg->second.custom_name, SimpleDelegate::FakeFusedRegistration().custom_name); } +TEST_F(TestDelegate, SetBufferHandleToInput) { + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 0; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + ASSERT_EQ(tensor->delegate, nullptr); + ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle); + + TfLiteDelegateBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = interpreter_->SetDelegateBufferHandle( + kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->delegate_buffer_handle, handle); +} + +TEST_F(TestDelegate, SetBufferHandleToOutput) { + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle); + + TfLiteDelegateBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = interpreter_->SetDelegateBufferHandle( + kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->delegate_buffer_handle, handle); +} + +TEST_F(TestDelegate, SetInvalidHandleToTensor) { + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + SimpleDelegate another_simple_delegate({0, 1, 2}); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle); + + TfLiteDelegateBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = interpreter_->SetDelegateBufferHandle( + kOutputTensorIndex, handle, + another_simple_delegate.get_tf_lite_delegate()); + // Setting a buffer handle to a tensor with another delegate will fail. + ASSERT_EQ(status, kTfLiteError); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle); +} + } // namespace } // namespace tflite |