aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/interpreter_test.cc
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2018-02-09 15:45:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-09 15:48:34 -0800
commit2adb6bbb1b4d31bba7113a4213bf5e7f0e154c78 (patch)
tree309680694a800d7ff3605e3a387418a0f4ef59f3 /tensorflow/contrib/lite/interpreter_test.cc
parent40ec7202b63c32f2f5ed57116096e33677c4b5df (diff)
Add delegate API to tflite.
- Context gets GetNodes, num_nodes and PartitionNodesIntoSubgraphs. - TfLiteDelegate provides one function that need be implemented - Delegates choose nodes and those nodes are all compacted into a new macro kernel. PiperOrigin-RevId: 185204338
Diffstat (limited to 'tensorflow/contrib/lite/interpreter_test.cc')
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc146
1 files changed, 145 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index cfda19d72c..4b309748f7 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -16,9 +16,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/testing/util.h"
-
namespace tflite {
namespace {
@@ -687,6 +688,149 @@ TEST_F(TestExecutionPlan, NullExecutionPlan) {
ASSERT_EQ(run_order_, std::vector<int>());
}
+// Build a kernel registration for an op that copies its one input
+// to an output
+TfLiteRegistration AddOpRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ reg.custom_name = "my_add";
+ reg.builtin_code = tflite::BuiltinOperator_CUSTOM;
+
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ // Set output size to input size
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims);
+ TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size);
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize));
+ return kTfLiteOk;
+ };
+
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ // Copy input data to output data.
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ out->data.f[i] = a0->data.f[i] + a1->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+ return reg;
+}
+
+class TestDelegate : public ::testing::Test {
+ public:
+ TestDelegate() {
+ interpreter_.AddTensors(5);
+ interpreter_.SetInputs({0, 1});
+ interpreter_.SetOutputs({3, 4});
+ TfLiteQuantizationParams quant;
+ interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+ quant);
+ TfLiteRegistration reg = AddOpRegistration();
+ interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
+ interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
+ interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+ }
+
+ protected:
+ class SimpleDelegate {
+ public:
+ // Create a simple implementation of a TfLiteDelegate. We use the C++ class
+ // SimpleDelegate and it can produce a handle TfLiteDelegate that is
+ // value-copyable and compatible with TfLite.
+ explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
+ delegate_.Prepare = [](TfLiteContext* context,
+ void* data) -> TfLiteStatus {
+ auto* simple = reinterpret_cast<SimpleDelegate*>(data);
+ TfLiteIntArray* nodes_to_separate =
+ TfLiteIntArrayCreate(simple->nodes_.size());
+ // Mark nodes that we want in TfLiteIntArray* structure.
+ int index = 0;
+ for (auto node_index : simple->nodes_) {
+ nodes_to_separate->data[index++] = node_index;
+ // make sure node is add
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
+ TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
+ }
+ // Check that all nodes are available
+ TfLiteIntArray* execution_plan;
+ TF_LITE_ENSURE_STATUS(
+ context->GetExecutionPlan(context, &execution_plan));
+ for (int exec_index = 0; exec_index < execution_plan->size;
+ exec_index++) {
+ int node_index = execution_plan->data[exec_index];
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
+ TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
+ }
+
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, FakeFusedRegistration(), nodes_to_separate);
+ TfLiteIntArrayFree(nodes_to_separate);
+ return kTfLiteOk;
+ };
+ // Store type-punned data SimpleDelegate structure.
+ delegate_.data_ = reinterpret_cast<void*>(this);
+ }
+
+ static TfLiteRegistration FakeFusedRegistration() {
+ TfLiteRegistration reg = {nullptr};
+ reg.custom_name = "fake_fused_op";
+ return reg;
+ }
+
+ TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; }
+
+ private:
+ std::vector<int> nodes_;
+ TfLiteDelegate delegate_;
+ };
+ Interpreter interpreter_;
+};
+
+TEST_F(TestDelegate, BasicDelegate) {
+ interpreter_.Invoke();
+ SimpleDelegate simple({0, 1, 2});
+ interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+
+ ASSERT_EQ(interpreter_.execution_plan().size(), 1);
+ int node = interpreter_.execution_plan()[0];
+ const auto* node_and_reg = interpreter_.node_and_registration(node);
+ ASSERT_EQ(node_and_reg->second.custom_name,
+ SimpleDelegate::FakeFusedRegistration().custom_name);
+}
+
+TEST_F(TestDelegate, ComplexDeligate) {
+ interpreter_.Invoke();
+ SimpleDelegate simple({1, 2});
+ interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+
+ ASSERT_EQ(interpreter_.execution_plan().size(), 2);
+ // 0th should be a non-delegated original op
+ ASSERT_EQ(interpreter_.execution_plan()[0], 0);
+ // 1st should be a new macro op (3) which didn't exist)
+ ASSERT_EQ(interpreter_.execution_plan()[1], 3);
+ const auto* node_and_reg = interpreter_.node_and_registration(3);
+ ASSERT_EQ(node_and_reg->second.custom_name,
+ SimpleDelegate::FakeFusedRegistration().custom_name);
+}
+
} // namespace
} // namespace tflite