Diffstat (limited to 'tensorflow/contrib/lite/delegates')
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/BUILD             |  37
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/buffer_map.cc     |   4
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/delegate_data.cc  |   3
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/kernel.cc         | 289
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/kernel.h          |  34
-rw-r--r--  tensorflow/contrib/lite/delegates/eager/kernel_test.cc    | 351
6 files changed, 717 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 03a4b7bf1d..a28707382e 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -68,6 +68,43 @@ cc_test(
)
cc_library(
+ name = "kernel",
+ srcs = ["kernel.cc"],
+ hdrs = ["kernel.h"],
+ deps = [
+ ":delegate_data",
+ ":util",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/common_runtime/eager:context",
+ "//tensorflow/core/common_runtime/eager:execute",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
+ "@flatbuffers",
+ ],
+)
+
+cc_test(
+ name = "kernel_test",
+ size = "small",
+ srcs = ["kernel_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":delegate_data",
+ ":kernel",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
index 1d6453f498..e5a19c3997 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
@@ -91,6 +91,10 @@ void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
for (int i = 0; i < num_dims; ++i) {
shape.AddDim(tensor->dims->data[i]);
}
+ // TODO(ahentz): we assume this is a new tensor and allocate a new buffer
+ // for it. This is not always the best approach. For example, this might
+ // be a reallocation after resizing tensors. In that case it would be
+ // preferable to somehow reuse the buffer.
auto* buf = new TfLiteTensorBuffer(tensor);
tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor(
GetTensorFlowDataType(tensor->type), shape, buf);
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
index 29687694bd..0fd5c976f8 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
@@ -23,7 +23,8 @@ tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
- tensorflow::SessionOptions(), "/device:cpu:*", &devices));
+ tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
+ &devices));
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
new file mode 100644
index 0000000000..1727981807
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+
+#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/builtin_ops.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+
+// Note: this is part of TF Lite's Eager delegation code, which is to be
+// completed soon.
+
+// This is the TF Lite op that is created by the eager delegate to handle
+// execution of a supported subgraph. The usual flow is that the delegate
+// informs the interpreter of supported nodes in a graph, and each supported
+// subgraph is replaced with one instance of this kernel.
+//
+// The kernel is initialized with TfLiteDelegateParams from which we retrieve
+// the global EagerContext and BufferMap, as well as a list of inputs and
+// outputs to the subgraph. Those are used to build the OpData, with a list of
+// TensorFlow Ops that should be executed in order (which we call an OpNode).
+//
+// For each node included in the subgraph, we query the interpreter and
+// retrieve the associated NodeDef, which is then used to configure the
+// corresponding TensorFlow/Eager Op.
+
+namespace tflite {
+namespace eager {
+namespace kernel {
+
+// Controls the lifetime of tensor handles in a vector.
+class VectorOfHandles {
+ public:
+ explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {}
+
+ ~VectorOfHandles() {
+ for (auto* handle : vector_) {
+ if (handle) handle->Unref();
+ }
+ }
+
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() {
+ return &vector_;
+ }
+
+ tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; }
+
+ private:
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
+};
+
+// Executes the TensorFlow op given by 'op_name', with the attributes specified
+// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
+tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
+ BufferMap* buffer_map, const string& op_name,
+ const tensorflow::NodeDef& nodedef,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ const tensorflow::AttrTypeMap* attr_types;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types),
+ " (while processing attributes of '", op_name, "')");
+
+ tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types);
+ for (const auto& attr : nodedef.attr()) {
+ op.MutableAttrs()->Set(attr.first, attr.second);
+ }
+
+ for (int input_index : inputs) {
+ if (!buffer_map->HasTensor(input_index)) {
+ return tensorflow::errors::Internal(
+ "Cannot read from invalid tensor index ", input_index);
+ }
+ auto* handle = new tensorflow::TensorHandle(
+ buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
+ op.AddInput(handle);
+ handle->Unref();
+ }
+
+ int num_retvals = outputs.size();
+ VectorOfHandles retvals(num_retvals);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EagerExecute(&op, retvals.GetVector(), &num_retvals),
+ " (while executing '", op_name, "' via Eager)");
+
+ if (num_retvals != outputs.size()) {
+ return tensorflow::errors::Internal(
+ "Unexpected number of outputs from EagerExecute");
+ }
+
+ for (int i = 0; i < num_retvals; ++i) {
+ const tensorflow::Tensor* tensor = nullptr;
+ TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor));
+ buffer_map->SetFromTensorFlow(outputs[i], *tensor);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+// A single node within the larger 'op'. Note that this kernel executes many
+// TensorFlow ops within a single TF Lite op.
+struct OpNode {
+ // The name of the TensorFlow op to execute.
+ string name;
+ // The corresponding NodeDef, containing the attributes for the op.
+ tensorflow::NodeDef nodedef;
+ // List of inputs, as TF Lite tensor indices.
+ std::vector<int> inputs;
+ // List of outputs, as TF Lite tensor indices.
+ std::vector<int> outputs;
+};
+
+// The larger 'op', which contains all the nodes in a supported subgraph.
+struct OpData {
+ tensorflow::EagerContext* eager_context;
+ BufferMap* buffer_map;
+ std::vector<OpNode> nodes;
+ std::vector<int> subgraph_inputs;
+ std::vector<int> subgraph_outputs;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+ CHECK(params);
+ CHECK(params->delegate);
+ CHECK(params->delegate->data_);
+ op_data->eager_context =
+ reinterpret_cast<DelegateData*>(params->delegate->data_)
+ ->GetEagerContext();
+ op_data->buffer_map =
+ reinterpret_cast<DelegateData*>(params->delegate->data_)->GetBufferMap();
+
+ CHECK(params->output_tensors);
+ for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
+ op_data->subgraph_outputs.push_back(tensor_index);
+ }
+
+ CHECK(params->input_tensors);
+ for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
+ op_data->subgraph_inputs.push_back(tensor_index);
+ }
+
+ CHECK(params->nodes_to_replace);
+ for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+
+ op_data->nodes.push_back(OpNode());
+ OpNode& node_data = op_data->nodes.back();
+
+ node_data.name = "";
+ if (node->custom_initial_data) {
+ // The flexbuffer contains a vector where the first element is the
+ // op name and the second is a serialized NodeDef.
+ const flexbuffers::Vector& v =
+ flexbuffers::GetRoot(
+ reinterpret_cast<const uint8_t*>(node->custom_initial_data),
+ node->custom_initial_data_size)
+ .AsVector();
+
+ node_data.name = v[0].AsString().str();
+ if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) {
+ // We will just leave the nodedef empty and error out in Eval().
+ node_data.nodedef.Clear();
+ }
+ }
+
+ for (auto input_index : TfLiteIntArrayView(node->inputs)) {
+ node_data.inputs.push_back(input_index);
+ }
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ node_data.outputs.push_back(output_index);
+ }
+ }
+
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_MSG(
+ context, op_data->eager_context != nullptr,
+ "Failed to initialize eager context. This often happens when a CPU "
+ "device has not been registered, presumably because some symbols from "
+ "tensorflow/core:core_cpu_impl were not linked into the binary.");
+
+ // Whenever we find a constant tensor, insert it in the buffer map.
+ BufferMap* buffer_map = op_data->buffer_map;
+ for (auto tensor_index : op_data->subgraph_inputs) {
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ if (IsConstantTensor(tensor)) {
+ if (!buffer_map->HasTensor(tensor_index)) {
+ buffer_map->SetFromTfLite(tensor_index, tensor);
+ }
+ }
+ }
+
+ // All output tensors are allocated by TensorFlow/Eager, so we
+ // mark them as kTfLiteDynamic.
+ for (auto tensor_index : op_data->subgraph_outputs) {
+ SetTensorToDynamic(&context->tensors[tensor_index]);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ BufferMap* buffer_map = op_data->buffer_map;
+ tensorflow::EagerContext* eager_context = op_data->eager_context;
+
+ // Insert a tensor in the buffer map for all inputs that are not constant.
+ // Constants were handled in Prepare() already.
+ for (auto tensor_index : op_data->subgraph_inputs) {
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ if (!IsConstantTensor(tensor)) {
+ buffer_map->SetFromTfLite(tensor_index, tensor);
+ }
+ }
+
+ // Execute the TensorFlow Ops sequentially.
+ for (const auto& node_data : op_data->nodes) {
+ if (node_data.nodedef.op().empty()) {
+ context->ReportError(context, "Invalid NodeDef in Eager op '%s'",
+ node_data.name.c_str());
+ return kTfLiteError;
+ }
+ auto status =
+ ExecuteEagerOp(eager_context, buffer_map, node_data.name,
+ node_data.nodedef, node_data.inputs, node_data.outputs);
+ TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
+ }
+
+ for (auto tensor_index : op_data->subgraph_outputs) {
+ if (!buffer_map->HasTensor(tensor_index)) {
+ context->ReportError(context, "Cannot write to invalid tensor index %d",
+ tensor_index);
+ return kTfLiteError;
+ }
+
+ TfLiteTensor* tensor = &context->tensors[tensor_index];
+ TF_LITE_ENSURE_OK(
+ context,
+ CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+ tensor->buffer_handle = tensor_index;
+ tensor->data_is_stale = true;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace kernel
+
+TfLiteRegistration GetKernel() {
+ TfLiteRegistration registration{&kernel::Init, &kernel::Free,
+ &kernel::Prepare, &kernel::Eval,
+ nullptr, kTfLiteBuiltinDelegate};
+ return registration;
+}
+
+} // namespace eager
+} // namespace tflite
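
A minimal sketch (not part of this change) of the producing side of the flexbuffer layout that kernel::Init() parses above: element 0 is the TensorFlow op name and element 1 is the serialized NodeDef, the same packing used by AddTfOp() in kernel_test.cc below. The helper name PackEagerOpData is purely illustrative.

#include <string>
#include <vector>

#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Illustrative helper: builds the blob that would be passed as the custom
// op's custom_initial_data, matching what kernel::Init() expects to find.
std::vector<uint8_t> PackEagerOpData(const tensorflow::NodeDef& nodedef) {
  std::string serialized_nodedef;
  nodedef.SerializeToString(&serialized_nodedef);

  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());         // element 0: op name, e.g. "Unpack"
    fbb.String(serialized_nodedef);   // element 1: serialized NodeDef
  });
  fbb.Finish();
  return fbb.GetBuffer();
}
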
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
new file mode 100644
index 0000000000..100672c82d
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace eager {
+
+// Return the registration object used to initialize and execute ops that will
+// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
+// the eager delegate to handle execution of a supported subgraph. The usual
+// flow is that the delegate informs the interpreter of supported nodes in a
+// graph, and each supported subgraph is replaced with one instance of this
+// kernel.
+TfLiteRegistration GetKernel();
+
+} // namespace eager
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
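
A minimal sketch (illustrative, not part of this change) of how a delegate wires GetKernel() into the interpreter; it mirrors the GenericPrepare() helper in kernel_test.cc below. The function name and the extra supported_nodes parameter are assumptions for illustration only; a real TfLiteDelegate::Prepare receives just the context and the delegate.

#include <vector>

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"

// Illustrative wiring: the delegate hands the supported node indices plus the
// eager kernel registration to ReplaceSubgraphsWithDelegateKernels(), which
// replaces each supported subgraph with one instance of this kernel.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate,
                             const std::vector<int>& supported_nodes) {
  TfLiteIntArray* nodes = ConvertVectorToTfLiteIntArray(supported_nodes);
  TfLiteStatus status = context->ReplaceSubgraphsWithDelegateKernels(
      context, tflite::eager::GetKernel(), nodes, delegate);
  TfLiteIntArrayFree(nodes);
  return status;
}
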
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
new file mode 100644
index 0000000000..7d9dddef93
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
@@ -0,0 +1,351 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace eager {
+namespace {
+
+using tensorflow::protobuf::TextFormat;
+using ::testing::ContainsRegex;
+using ::testing::ElementsAre;
+
+// We will use these as custom_names, so they need to be static.
+static const char kIdentity[] = "Identity";
+static const char kUnpack[] = "Unpack";
+static const char kAdd[] = "Add";
+static const char kMul[] = "Mul";
+
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
+ const std::vector<int>& supported_nodes) {
+ TfLiteIntArray* size_and_nodes =
+ ConvertVectorToTfLiteIntArray(supported_nodes);
+ TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels(
+ context, eager::GetKernel(), size_and_nodes, delegate));
+ TfLiteIntArrayFree(size_and_nodes);
+ return kTfLiteOk;
+}
+
+class KernelTest : public ::testing::Test {
+ public:
+ KernelTest() {
+ CHECK(DelegateData::Create(&delegate_data_).ok());
+ interpreter_.reset(new Interpreter(&error_reporter_));
+ }
+
+ bool Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
+
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ float* v = interpreter_->typed_tensor<float>(tensor_index);
+ for (float f : values) {
+ *v++ = f;
+ }
+ }
+
+ std::vector<float> GetValues(int tensor_index) {
+ TfLiteTensor* o = interpreter_->tensor(tensor_index);
+ return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
+ }
+
+ void SetShape(int tensor_index, const std::vector<int>& values) {
+ ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+ }
+
+ std::vector<int> GetShape(int tensor_index) {
+ std::vector<int> result;
+ auto* dims = interpreter_->tensor(tensor_index)->dims;
+ for (int i = 0; i < dims->size; ++i) {
+ result.push_back(dims->data[i]);
+ }
+ return result;
+ }
+
+ template <typename T>
+ void ConfigureDelegate(T prepare_function) {
+ delegate_.data_ = delegate_data_.get();
+ delegate_.FreeBufferHandle = nullptr;
+ delegate_.Prepare = prepare_function;
+ delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size) {
+ auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_);
+ tensorflow::StringPiece values =
+ delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data();
+ memcpy(data, values.data(), values.size());
+ return kTfLiteOk;
+ };
+ CHECK(interpreter_->ModifyGraphWithDelegate(
+ &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk);
+ }
+
+ void AddOp(const char* name, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ auto attr = [](const string& key, const string& value) {
+ return " attr{ key: '" + key + "' value {" + value + "}}";
+ };
+
+ string attributes;
+ if (name == string(kUnpack)) {
+ attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
+ attr("axis", "i: 0");
+ } else if (name == string(kIdentity)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ } else if (name == string(kAdd)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ } else if (name == string(kMul)) {
+ attributes = attr("T", "type: DT_FLOAT");
+ }
+ AddTfOp(name, attributes, inputs, outputs);
+ }
+
+ void AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ interpreter_->AddTensors(num_tensors);
+ for (int i = 0; i < num_tensors; ++i) {
+ TfLiteQuantizationParams quant;
+ CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32,
+ /*name=*/"",
+ /*dims=*/{3}, quant),
+ kTfLiteOk);
+ }
+
+ CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
+ CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
+ }
+
+ const TestErrorReporter& error_reporter() const { return error_reporter_; }
+
+ void AddTfLiteOp(const char* name, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ CHECK_EQ(string(name), kMul); // can only add MUL
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_MUL;
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* i1 = &context->tensors[node->inputs->data[1]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ for (int i = 0; i < o->bytes / sizeof(float); ++i) {
+ o->data.f[i] = i0->data.f[i] * i1->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+
+ CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
+ nullptr, &reg),
+ kTfLiteOk);
+ }
+
+ private:
+ void AddTfOp(const char* name, const string& nodedef_str,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_CUSTOM;
+ reg.custom_name = name;
+
+ tensorflow::NodeDef nodedef;
+ CHECK(TextFormat::ParseFromString(nodedef_str + " op: '" + name + "'",
+ &nodedef));
+ string serialized_nodedef;
+ CHECK(nodedef.SerializeToString(&serialized_nodedef));
+ flexbuffers::Builder fbb;
+ fbb.Vector([&]() {
+ fbb.String(nodedef.op());
+ fbb.String(serialized_nodedef);
+ });
+ fbb.Finish();
+
+ flexbuffers_.push_back(fbb.GetBuffer());
+ auto& buffer = flexbuffers_.back();
+ CHECK_EQ(interpreter_->AddNodeWithParameters(
+ inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
+ buffer.size(), nullptr, &reg),
+ kTfLiteOk);
+ }
+
+ std::unique_ptr<Interpreter> interpreter_;
+ std::unique_ptr<DelegateData> delegate_data_;
+ TfLiteDelegate delegate_;
+ std::vector<std::vector<uint8_t>> flexbuffers_;
+ TestErrorReporter error_reporter_;
+};
+
+TEST_F(KernelTest, FullGraph) {
+ // Define the graph.
+ AddTensors(9, {0, 3}, {8});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kUnpack, {3}, {4, 5});
+ AddOp(kAdd, {1, 4}, {6});
+ AddOp(kAdd, {2, 5}, {7});
+ AddOp(kMul, {6, 7}, {8});
+
+ // Apply Delegate.
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 3, 4});
+ });
+
+ // Define inputs.
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(KernelTest, BadTensorFlowOp) {
+ AddTensors(2, {0}, {1});
+ AddOp("NonExistentOp", {0}, {1});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("while processing attributes of 'NonExistentOp'"));
+}
+
+TEST_F(KernelTest, BadNumberOfOutputs) {
+ AddTensors(3, {0}, {1, 2});
+ AddOp(kIdentity, {0}, {1, 2});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("Unexpected number of outputs"));
+}
+
+TEST_F(KernelTest, IncompatibleNodeDef) {
+ AddTensors(2, {0}, {1});
+
+ // Cast is a TF op, but we don't add the proper nodedef to it in AddOp.
+ AddOp("Cast", {0}, {1});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("while executing 'Cast' via Eager"));
+}
+
+TEST_F(KernelTest, WrongSetOfNodes) {
+ AddTensors(4, {0}, {3});
+ AddOp(kUnpack, {0}, {1, 2});
+ AddTfLiteOp(kMul, {1, 2}, {3});
+
+ // Specify that kMul (#1) is supported when it actually isn't.
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_FALSE(Invoke());
+ ASSERT_THAT(error_reporter().error_messages(),
+ ContainsRegex("Invalid NodeDef in Eager op"));
+}
+
+TEST_F(KernelTest, MixedGraph) {
+ AddTensors(9, {0, 3}, {8});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kUnpack, {3}, {4, 5});
+ AddOp(kAdd, {1, 4}, {6});
+ AddOp(kAdd, {2, 5}, {7});
+ AddTfLiteOp(kMul, {6, 7}, {8});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 3});
+ });
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(KernelTest, SplitGraph) {
+ AddTensors(10, {0}, {9});
+
+ AddOp(kUnpack, {0}, {1, 2});
+ AddOp(kAdd, {1, 2}, {3});
+ AddOp(kUnpack, {3}, {4, 5});
+
+ AddTfLiteOp(kMul, {4, 5}, {6});
+
+ AddOp(kUnpack, {6}, {7, 8});
+ AddOp(kAdd, {7, 8}, {9});
+
+ ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
+ return GenericPrepare(context, delegate, {0, 1, 2, 4, 5});
+ });
+
+ SetShape(0, {2, 2, 2, 1});
+ SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(9), ElementsAre(1));
+ ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
+}
+
+} // namespace
+} // namespace eager
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}