diff options
author | 2018-02-09 15:45:00 -0800 | |
---|---|---|
committer | 2018-02-09 15:48:34 -0800 | |
commit | 2adb6bbb1b4d31bba7113a4213bf5e7f0e154c78 (patch) | |
tree | 309680694a800d7ff3605e3a387418a0f4ef59f3 /tensorflow/contrib/lite/interpreter_test.cc | |
parent | 40ec7202b63c32f2f5ed57116096e33677c4b5df (diff) |
Add delegate API to tflite.
- Context gets GetNodes, num_nodes and PartitionNodesIntoSubgraphs.
- TfLiteDelegate provides one function that need be implemented
- Delegates choose nodes and those nodes are all compacted into
a new macro kernel.
PiperOrigin-RevId: 185204338
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 146 |
1 files changed, 145 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index cfda19d72c..4b309748f7 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -16,9 +16,10 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include <gtest/gtest.h> #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/testing/util.h" - namespace tflite { namespace { @@ -687,6 +688,149 @@ TEST_F(TestExecutionPlan, NullExecutionPlan) { ASSERT_EQ(run_order_, std::vector<int>()); } +// Build a kernel registration for an op that copies its one input +// to an output +TfLiteRegistration AddOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.custom_name = "my_add"; + reg.builtin_code = tflite::BuiltinOperator_CUSTOM; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims); + TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize)); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* out = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + return kTfLiteOk; + }; + return reg; +} + +class TestDelegate : public ::testing::Test { + public: + TestDelegate() { + interpreter_.AddTensors(5); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({3, 4}); + TfLiteQuantizationParams quant; + interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = AddOpRegistration(); + interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + protected: + class SimpleDelegate { + public: + // Create a simple implementation of a TfLiteDelegate. We use the C++ class + // SimpleDelegate and it can produce a handle TfLiteDelegate that is + // value-copyable and compatible with TfLite. + explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) { + delegate_.Prepare = [](TfLiteContext* context, + void* data) -> TfLiteStatus { + auto* simple = reinterpret_cast<SimpleDelegate*>(data); + TfLiteIntArray* nodes_to_separate = + TfLiteIntArrayCreate(simple->nodes_.size()); + // Mark nodes that we want in TfLiteIntArray* structure. + int index = 0; + for (auto node_index : simple->nodes_) { + nodes_to_separate->data[index++] = node_index; + // make sure node is add + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + // Check that all nodes are available + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + for (int exec_index = 0; exec_index < execution_plan->size; + exec_index++) { + int node_index = execution_plan->data[exec_index]; + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + + context->ReplaceSubgraphsWithDelegateKernels( + context, FakeFusedRegistration(), nodes_to_separate); + TfLiteIntArrayFree(nodes_to_separate); + return kTfLiteOk; + }; + // Store type-punned data SimpleDelegate structure. + delegate_.data_ = reinterpret_cast<void*>(this); + } + + static TfLiteRegistration FakeFusedRegistration() { + TfLiteRegistration reg = {nullptr}; + reg.custom_name = "fake_fused_op"; + return reg; + } + + TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } + + private: + std::vector<int> nodes_; + TfLiteDelegate delegate_; + }; + Interpreter interpreter_; +}; + +TEST_F(TestDelegate, BasicDelegate) { + interpreter_.Invoke(); + SimpleDelegate simple({0, 1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 1); + int node = interpreter_.execution_plan()[0]; + const auto* node_and_reg = interpreter_.node_and_registration(node); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + +TEST_F(TestDelegate, ComplexDeligate) { + interpreter_.Invoke(); + SimpleDelegate simple({1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 2); + // 0th should be a non-delegated original op + ASSERT_EQ(interpreter_.execution_plan()[0], 0); + // 1st should be a new macro op (3) which didn't exist) + ASSERT_EQ(interpreter_.execution_plan()[1], 3); + const auto* node_and_reg = interpreter_.node_and_registration(3); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + } // namespace } // namespace tflite |