aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter_test.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-19 14:27:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 14:37:01 -0700
commitff43dff34ab525dd333128c73ebfb0f9723c34c0 (patch)
tree395c315428a0f084cfce47a562154ec15474e047 /tensorflow/contrib/lite/interpreter_test.cc
parente613e0844a95814457f3530eedb9baf812cf1e87 (diff)
TFLite Delegate: Add an `allow_dynamic_tensors` parameter.
PiperOrigin-RevId: 189641833
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc118
1 files changed, 102 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 7a029c7df8..efb29d5c9d 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -17,9 +17,11 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/testing/util.h"
+
namespace tflite {
namespace {
@@ -439,12 +441,12 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
// String-in String-out node.
TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
DynamicBuffer buf;
- StringRef str_ref = GetString(a0, 0);
+ StringRef str_ref = GetString(input, 0);
buf.AddString(str_ref);
- buf.WriteToTensor(a1);
+ buf.WriteToTensor(output);
return kTfLiteOk;
};
@@ -778,13 +780,17 @@ TfLiteRegistration AddOpRegistration() {
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
- TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]];
- TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]];
- TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
- TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims);
- TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size);
- TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize));
+ TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+
+ TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
+ for (int i = 0; i < input1->dims->size; ++i) {
+ TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
+ }
+
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(
+ context, output, TfLiteIntArrayCopy(input1->dims)));
return kTfLiteOk;
};
@@ -818,6 +824,8 @@ class TestDelegate : public ::testing::Test {
quant);
interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
quant);
+ interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
+ quant);
TfLiteRegistration reg = AddOpRegistration();
interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
@@ -916,7 +924,6 @@ class TestDelegate : public ::testing::Test {
};
TEST_F(TestDelegate, BasicDelegate) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
@@ -944,7 +951,6 @@ TEST_F(TestDelegate, BasicDelegate) {
}
TEST_F(TestDelegate, ComplexDeligate) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
@@ -959,7 +965,6 @@ TEST_F(TestDelegate, ComplexDeligate) {
}
TEST_F(TestDelegate, SetBufferHandleToInput) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
interpreter_->ModifyGraphWithDelegate(delegate);
@@ -978,7 +983,6 @@ TEST_F(TestDelegate, SetBufferHandleToInput) {
}
TEST_F(TestDelegate, SetBufferHandleToOutput) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
interpreter_->ModifyGraphWithDelegate(delegate);
@@ -1002,7 +1006,7 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) {
interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
- interpreter_->ModifyGraphWithDelegate(delegate);
+ interpreter_->ModifyGraphWithDelegate(delegate, true);
SimpleDelegate another_simple_delegate({0, 1, 2});
@@ -1023,6 +1027,88 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) {
EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
}
+TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) {
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk);
+ ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk);
+ ASSERT_EQ(
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError);
+}
+
+class TestDelegateWithDynamicTensors : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+
+ interpreter_->AddTensors(2);
+ interpreter_->SetInputs({0});
+ interpreter_->SetOutputs({1});
+ TfLiteQuantizationParams quant;
+ interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ TfLiteRegistration reg = DynamicCopyOpRegistration();
+ interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
+
+ delegate_.Prepare = [](TfLiteContext* context,
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ // In this test, the delegate replaces all the nodes if this function is
+ // called.
+ TfLiteIntArray* execution_plan;
+ TF_LITE_ENSURE_STATUS(
+ context->GetExecutionPlan(context, &execution_plan));
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, DelegateRegistration(), execution_plan, delegate);
+ return kTfLiteOk;
+ };
+ }
+
+ static TfLiteRegistration DynamicCopyOpRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ };
+
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ // Not implemented since this isn't required in testing.
+ return kTfLiteOk;
+ };
+ return reg;
+ }
+
+ static TfLiteRegistration DelegateRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ return reg;
+ }
+
+ std::unique_ptr<Interpreter> interpreter_;
+ TfLiteDelegate delegate_;
+};
+
+TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) {
+ interpreter_->ModifyGraphWithDelegate(&delegate_, false);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ // The interpreter should not call delegate's `Prepare` when dynamic tensors
+ // exist. So the node ID isn't changed.
+ ASSERT_EQ(interpreter_->execution_plan()[0], 0);
+}
+
+TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
+ interpreter_->ModifyGraphWithDelegate(&delegate_, true);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ // The node should be replaced because dynamic tensors are allowed. Therefore
+ // only node ID in the execution plan is changed from 0 to 1.
+ ASSERT_EQ(interpreter_->execution_plan()[0], 1);
+}
+
} // namespace
} // namespace tflite