aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter_test.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-08 11:56:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 12:00:28 -0800
commit6e3a43f4b7a1288c878b5daff274f1229256fbe8 (patch)
tree65dbe6000d5bf896015d64519a2d5783c1857e83 /tensorflow/contrib/lite/interpreter_test.cc
parent214ad0978641a946c25b334c4a33ecd1793b4d70 (diff)
TFLite: Delegate Buffer Handle interface (take 2)
PiperOrigin-RevId: 188366045
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc160
1 files changed, 127 insertions, 33 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 2e6727b323..2586c15287 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -763,26 +763,38 @@ TfLiteRegistration AddOpRegistration() {
}
class TestDelegate : public ::testing::Test {
- public:
- TestDelegate() {
- interpreter_.AddTensors(5);
- interpreter_.SetInputs({0, 1});
- interpreter_.SetOutputs({3, 4});
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+ interpreter_->AddTensors(5);
+ interpreter_->SetInputs({0, 1});
+ interpreter_->SetOutputs({3, 4});
TfLiteQuantizationParams quant;
- interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
- quant);
+ interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+ quant);
TfLiteRegistration reg = AddOpRegistration();
- interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
- interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
- interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
}
+ void TearDown() override {
+ // Interpreter relies on delegate_ to free the resources properly. Thus
+ // the life cycle of delegate must be longer than interpreter.
+ interpreter_.reset();
+ delegate_.reset();
+ }
+
+ TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle;
+
+ TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; }
+
protected:
class SimpleDelegate {
public:
@@ -791,8 +803,8 @@ class TestDelegate : public ::testing::Test {
// value-copyable and compatible with TfLite.
explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
delegate_.Prepare = [](TfLiteContext* context,
- void* data) -> TfLiteStatus {
- auto* simple = reinterpret_cast<SimpleDelegate*>(data);
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_);
TfLiteIntArray* nodes_to_separate =
TfLiteIntArrayCreate(simple->nodes_.size());
// Mark nodes that we want in TfLiteIntArray* structure.
@@ -823,10 +835,26 @@ class TestDelegate : public ::testing::Test {
}
context->ReplaceSubgraphsWithDelegateKernels(
- context, FakeFusedRegistration(), nodes_to_separate);
+ context, FakeFusedRegistration(), nodes_to_separate, delegate);
TfLiteIntArrayFree(nodes_to_separate);
return kTfLiteOk;
};
+ delegate_.CopyToBufferHandle = [](TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, int size) -> TfLiteStatus {
+ // TODO(ycling): Implement tests to test buffer copying logic.
+ return kTfLiteOk;
+ };
+ delegate_.CopyFromBufferHandle =
+ [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle,
+ void* data, int size) -> TfLiteStatus {
+ // TODO(ycling): Implement tests to test buffer copying logic.
+ return kTfLiteOk;
+ };
+ delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle) {
+ *handle = kTfLiteNullBufferHandle;
+ };
// Store type-punned data SimpleDelegate structure.
delegate_.data_ = reinterpret_cast<void*>(this);
}
@@ -843,36 +871,102 @@ class TestDelegate : public ::testing::Test {
std::vector<int> nodes_;
TfLiteDelegate delegate_;
};
- Interpreter interpreter_;
+ std::unique_ptr<Interpreter> interpreter_;
+ std::unique_ptr<SimpleDelegate> delegate_;
};
TEST_F(TestDelegate, BasicDelegate) {
- interpreter_.Invoke();
- SimpleDelegate simple({0, 1, 2});
- interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
- ASSERT_EQ(interpreter_.execution_plan().size(), 1);
- int node = interpreter_.execution_plan()[0];
- const auto* node_and_reg = interpreter_.node_and_registration(node);
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ int node = interpreter_->execution_plan()[0];
+ const auto* node_and_reg = interpreter_->node_and_registration(node);
ASSERT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
}
TEST_F(TestDelegate, ComplexDeligate) {
- interpreter_.Invoke();
- SimpleDelegate simple({1, 2});
- interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
- ASSERT_EQ(interpreter_.execution_plan().size(), 2);
+ ASSERT_EQ(interpreter_->execution_plan().size(), 2);
// 0th should be a non-delegated original op
- ASSERT_EQ(interpreter_.execution_plan()[0], 0);
+ ASSERT_EQ(interpreter_->execution_plan()[0], 0);
// 1st should be a new macro op (3) which didn't exist)
- ASSERT_EQ(interpreter_.execution_plan()[1], 3);
- const auto* node_and_reg = interpreter_.node_and_registration(3);
+ ASSERT_EQ(interpreter_->execution_plan()[1], 3);
+ const auto* node_and_reg = interpreter_->node_and_registration(3);
ASSERT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
}
+TEST_F(TestDelegate, SetBufferHandleToInput) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ constexpr int kOutputTensorIndex = 0;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ ASSERT_EQ(tensor->delegate, nullptr);
+ ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status =
+ interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate);
+ ASSERT_EQ(status, kTfLiteOk);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->buffer_handle, handle);
+}
+
+TEST_F(TestDelegate, SetBufferHandleToOutput) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ constexpr int kOutputTensorIndex = 3;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ // Before setting the buffer handle, the tensor's `delegate` is already set
+ // because it will be written by the delegate.
+ ASSERT_EQ(tensor->delegate, delegate);
+ ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status =
+ interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate);
+ ASSERT_EQ(status, kTfLiteOk);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->buffer_handle, handle);
+}
+
+TEST_F(TestDelegate, SetInvalidHandleToTensor) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ SimpleDelegate another_simple_delegate({0, 1, 2});
+
+ constexpr int kOutputTensorIndex = 3;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ // Before setting the buffer handle, the tensor's `delegate` is already set
+ // because it will be written by the delegate.
+ ASSERT_EQ(tensor->delegate, delegate);
+ ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status = interpreter_->SetBufferHandle(
+ kOutputTensorIndex, handle,
+ another_simple_delegate.get_tf_lite_delegate());
+ // Setting a buffer handle to a tensor with another delegate will fail.
+ ASSERT_EQ(status, kTfLiteError);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
+}
+
} // namespace
} // namespace tflite