diff options
author | 2018-03-19 14:27:52 -0700 | |
---|---|---|
committer | 2018-03-19 14:37:01 -0700 | |
commit | ff43dff34ab525dd333128c73ebfb0f9723c34c0 (patch) | |
tree | 395c315428a0f084cfce47a562154ec15474e047 /tensorflow/contrib/lite/interpreter_test.cc | |
parent | e613e0844a95814457f3530eedb9baf812cf1e87 (diff) |
TFLite Delegate: Add an `allow_dynamic_tensors` parameter.
PiperOrigin-RevId: 189641833
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 118 |
1 files changed, 102 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 7a029c7df8..efb29d5c9d 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -17,9 +17,11 @@ limitations under the License. #include <gtest/gtest.h> #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/testing/util.h" + namespace tflite { namespace { @@ -439,12 +441,12 @@ TEST(BasicInterpreter, ThreeStepAllocate) { // String-in String-out node. TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; - TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; DynamicBuffer buf; - StringRef str_ref = GetString(a0, 0); + StringRef str_ref = GetString(input, 0); buf.AddString(str_ref); - buf.WriteToTensor(a1); + buf.WriteToTensor(output); return kTfLiteOk; }; @@ -778,13 +780,17 @@ TfLiteRegistration AddOpRegistration() { reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { // Set output size to input size - TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; - TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]]; - TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]]; - TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); - TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims); - TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size); - TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize)); + TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + + TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size); + for (int i = 0; i < input1->dims->size; ++i) { + TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]); + } + + TF_LITE_ENSURE_STATUS(context->ResizeTensor( + context, output, TfLiteIntArrayCopy(input1->dims))); return kTfLiteOk; }; @@ -818,6 +824,8 @@ class TestDelegate : public ::testing::Test { quant); interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, quant); + interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, + quant); TfLiteRegistration reg = AddOpRegistration(); interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); @@ -916,7 +924,6 @@ class TestDelegate : public ::testing::Test { }; TEST_F(TestDelegate, BasicDelegate) { - interpreter_->Invoke(); delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); @@ -944,7 +951,6 @@ TEST_F(TestDelegate, BasicDelegate) { } TEST_F(TestDelegate, ComplexDeligate) { - interpreter_->Invoke(); delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); @@ -959,7 +965,6 @@ TEST_F(TestDelegate, ComplexDeligate) { } TEST_F(TestDelegate, SetBufferHandleToInput) { - interpreter_->Invoke(); delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); interpreter_->ModifyGraphWithDelegate(delegate); @@ -978,7 +983,6 @@ TEST_F(TestDelegate, SetBufferHandleToInput) { } TEST_F(TestDelegate, SetBufferHandleToOutput) { - interpreter_->Invoke(); delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); interpreter_->ModifyGraphWithDelegate(delegate); @@ -1002,7 +1006,7 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) { interpreter_->Invoke(); delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); + interpreter_->ModifyGraphWithDelegate(delegate, true); SimpleDelegate another_simple_delegate({0, 1, 2}); @@ -1023,6 +1027,88 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) { EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); } +TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError); +} + +class TestDelegateWithDynamicTensors : public ::testing::Test { + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + + interpreter_->AddTensors(2); + interpreter_->SetInputs({0}); + interpreter_->SetOutputs({1}); + TfLiteQuantizationParams quant; + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = DynamicCopyOpRegistration(); + interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + delegate_.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // In this test, the delegate replaces all the nodes if this function is + // called. + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + context->ReplaceSubgraphsWithDelegateKernels( + context, DelegateRegistration(), execution_plan, delegate); + return kTfLiteOk; + }; + } + + static TfLiteRegistration DynamicCopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + SetTensorToDynamic(output); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Not implemented since this isn't required in testing. + return kTfLiteOk; + }; + return reg; + } + + static TfLiteRegistration DelegateRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + return reg; + } + + std::unique_ptr<Interpreter> interpreter_; + TfLiteDelegate delegate_; +}; + +TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { + interpreter_->ModifyGraphWithDelegate(&delegate_, false); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The interpreter should not call delegate's `Prepare` when dynamic tensors + // exist. So the node ID isn't changed. + ASSERT_EQ(interpreter_->execution_plan()[0], 0); +} + +TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) { + interpreter_->ModifyGraphWithDelegate(&delegate_, true); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The node should be replaced because dynamic tensors are allowed. Therefore + // only node ID in the execution plan is changed from 0 to 1. + ASSERT_EQ(interpreter_->execution_plan()[0], 1); +} + } // namespace } // namespace tflite |