aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-07 17:42:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-07 17:46:14 -0800
commit5594bc3c43f6829b7ea77f96852c98fb41e4deb2 (patch)
treea7d2f5771c1758ec5a2fbf69ebe79c8bab6d9693 /tensorflow/contrib/lite
parent9cdfd3878935fb6c3c2a5da7f65ee0db6c751170 (diff)
TFLite: Delegate Buffer Handle interface
PiperOrigin-RevId: 188263046
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/BUILD22
-rw-r--r--tensorflow/contrib/lite/context.c7
-rw-r--r--tensorflow/contrib/lite/context.h64
-rw-r--r--tensorflow/contrib/lite/interpreter.cc154
-rw-r--r--tensorflow/contrib/lite/interpreter.h45
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc164
-rw-r--r--tensorflow/contrib/lite/util.cc27
-rw-r--r--tensorflow/contrib/lite/util.h34
-rw-r--r--tensorflow/contrib/lite/util_test.cc50
9 files changed, 496 insertions, 71 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 44c4a7e2ca..5cfbb544b7 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -132,6 +132,7 @@ cc_library(
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
+ ":util",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/schema:schema_fbs",
@@ -232,6 +233,27 @@ cc_test(
],
)
+cc_library(
+ name = "util",
+ srcs = ["util.cc"],
+ hdrs = ["util.h"],
+ deps = [
+ ":context",
+ ],
+)
+
+cc_test(
+ name = "util_test",
+ size = "small",
+ srcs = ["util_test.cc"],
+ deps = [
+ ":context",
+ ":util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
# Test the serialization of a model with optional tensors.
# Model tests
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c
index c09e838c5c..620de5d678 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/context.c
@@ -17,9 +17,14 @@ limitations under the License.
#include <stdio.h>
#include <string.h>
+int TfLiteIntArrayGetSizeInBytes(int size) {
+ static TfLiteIntArray dummy;
+ return sizeof(dummy) + sizeof(dummy.data[0]) * size;
+}
+
TfLiteIntArray* TfLiteIntArrayCreate(int size) {
TfLiteIntArray* ret =
- (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size);
+ (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size));
ret->size = size;
return ret;
}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index ed7f4515fa..d901b9f065 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -29,6 +29,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
@@ -40,6 +41,7 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
// Forward declare so GetNode can use this in Context.
typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
#define kOptionalTensor (-1)
@@ -57,6 +59,10 @@ typedef struct {
#endif
} TfLiteIntArray;
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
// Create an array of a given `size` (uninitialized entries).
// This returns a pointer, that you must free using TfLiteIntArrayFree().
TfLiteIntArray* TfLiteIntArrayCreate(int size);
@@ -162,6 +168,11 @@ typedef enum {
kTfLiteDynamic,
} TfLiteAllocationType;
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved for unallocated status.
+typedef int TfLiteDelegateBufferHandle;
+const TfLiteDelegateBufferHandle kTfLiteNullBufferHandle = -1;
+
// A tensor in the interpreter system which is a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).
typedef struct {
@@ -194,6 +205,22 @@ typedef struct {
// Null-terminated name of this tensor.
const char* name;
+
+ // The delegate which knows how to handle `delegate_buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegateBufferHandle delegate_buffer_handle;
+
+ // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+ // responsible to set data_is_stale to true.
+ // `delegate->CopyFromBufferHandle` can be called to copy the data from
+ // delegate buffer.
+  // WARNING: This is an experimental interface that is subject to change.
+ bool data_is_stale;
} TfLiteTensor;
// Free memory of tensor `t`;
@@ -234,6 +261,11 @@ typedef struct {
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
} TfLiteNode;
typedef struct TfLiteContext {
@@ -287,7 +319,7 @@ typedef struct TfLiteContext {
// does not take ownership of `nodes_to_replace`.
TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
struct TfLiteContext*, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace);
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
// TODO(ahentz): we should create a more general mechanism for this sort of
// library-global objects.
@@ -338,19 +370,45 @@ typedef struct _TfLiteRegistration {
} TfLiteRegistration;
// WARNING: This is an experimental interface that is subject to change.
-typedef struct {
+typedef struct _TfLiteDelegate {
// Data that delegate needs to identify itself. This data is owned by the
// delegate. The delegate is owned in the user code, so the delegate is
// responsible for doing this when it is destroyed.
void* data_;
+
// Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
// delegate a view of the current graph through TfLiteContext*. It typically
// will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
// to ask the TensorFlow lite runtime to create macro-nodes to represent
// delegated subgraphs of the original graph.
- TfLiteStatus (*Prepare)(TfLiteContext* context, void* data);
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(
+ TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle delegate_buffer_handle, void* data, int size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(
+ TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle delegate_buffer_handle, void* data, int size);
+
+ // Free the Delegate Buffer Handle. Note: This only frees the handle, but
+ // this doesn't release the underlying resource (e.g. textures). The
+ // resources are either owned by application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle* handle);
} TfLiteDelegate;
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+} TfLiteDelegateParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 0f5e17f0de..733c47852e 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
@@ -96,19 +97,57 @@ Interpreter::~Interpreter() {
}
for (int i = 0; i < context_.tensors_size; i++) {
- TfLiteTensorFree(&context_.tensors[i]);
+ TfLiteTensor* tensor = &context_.tensors[i];
+ if (tensor->delegate_buffer_handle != kTfLiteNullBufferHandle) {
+ tensor->delegate->FreeBufferHandle(tensor->delegate,
+ &tensor->delegate_buffer_handle);
+ }
+ TfLiteTensorFree(tensor);
}
}
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
TfLiteContext* context, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace) {
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
return static_cast<Interpreter*>(context->impl_)
- ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
+ ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace,
+ delegate);
+}
+
+namespace {
+
+// This function allocates a continuous memory space that contains a
+// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be
+// deallocated by C `free` function later.
+TfLiteDelegateParams* CreateDelegateParams(
+ TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) {
+ int nodes_to_replace_size_in_bytes =
+ TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size());
+ void* allocation =
+ malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes);
+ TfLiteDelegateParams* params =
+ reinterpret_cast<TfLiteDelegateParams*>(allocation);
+ TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>(
+ static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams));
+
+ nodes_to_replace_arr->size = nodes_to_replace.size();
+ for (int i = 0; i < nodes_to_replace.size(); ++i) {
+ nodes_to_replace_arr->data[i] = nodes_to_replace[i];
+ }
+
+ params->delegate = delegate;
+ params->nodes_to_replace = nodes_to_replace_arr;
+ return params;
}
+} // Anonymous namespace
+
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
- TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
+ TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
+ TfLiteDelegate* delegate) {
// Annotate the registration as DELEGATE op.
registration.builtin_code = BuiltinOperator_DELEGATE;
@@ -120,30 +159,38 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
execution_plan_.clear();
for (auto& subgraph : subgraphs) {
- // Turn subgraph.nodes into a TfLiteIntArray compatible data structure.
- // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
- // in the first place
- subgraph.nodes.insert(subgraph.nodes.begin(),
- static_cast<int>(subgraph.nodes.size()));
// Subgraphs claimed by the delegate should have a "macro" op created, the
// other subgraphs (kTfNonPartition) just have their nodes added back to
// the execution plan.
switch (subgraph.type) {
case Subgraph::kTfNonPartition:
- for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
+ for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end();
++it) {
execution_plan_.push_back(*it);
}
break;
case Subgraph::kTfPartition: {
- void* builtin_data = nullptr;
int node_index;
- // Create a node that represents computation of this subgraph.
- AddNodeWithParameters(
- subgraph.input_tensors, subgraph.output_tensors,
- reinterpret_cast<const char*>(subgraph.nodes.data()),
- subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
- &registration, &node_index);
+
+ TfLiteDelegateParams* params =
+ CreateDelegateParams(delegate, subgraph.nodes);
+ AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
+ nullptr, 0, params, &registration, &node_index);
+
+      // Initialize the output tensors' delegate-related fields.
+ for (int tensor_index : subgraph.output_tensors) {
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+ TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr);
+ TF_LITE_ENSURE_EQ(&context_, tensor->delegate_buffer_handle,
+ kTfLiteNullBufferHandle);
+ // delegate_buffer_handle will be filled in delegate's `Prepare`
+ // function.
+ tensor->delegate = delegate;
+ }
+
+ // Associate the node with the delegate.
+ TfLiteNode* node = &nodes_and_registration_[node_index].first;
+ node->delegate = delegate;
} break;
case Subgraph::kTfUnexplored:
return kTfLiteError;
@@ -233,14 +280,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
return kTfLiteOk;
}
-namespace {
-TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
- TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
- for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
- return lite;
-}
-} // namespace
-
TfLiteStatus Interpreter::AllocateTensors() {
next_execution_plan_index_to_prepare_ = 0;
if (memory_planner_) {
@@ -275,7 +314,6 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
-
auto& node_and_reg = nodes_and_registration_.back();
TfLiteNode& node = node_and_reg.first;
if (node.inputs) TfLiteIntArrayFree(node.inputs);
@@ -285,8 +323,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
// NOTE, here we are not using move semantics yet, since our internal
// representation isn't std::vector, but in the future we would like to avoid
// copies, so we want the interface to take r-value references now.
- node.inputs = convertVectorToTfLiteIntArray(inputs);
- node.outputs = convertVectorToTfLiteIntArray(outputs);
+ node.inputs = ConvertVectorToTfLiteIntArray(inputs);
+ node.outputs = ConvertVectorToTfLiteIntArray(outputs);
node.temporaries = TfLiteIntArrayCreate(0);
if (init_data) {
node.user_data = OpInit(*registration, init_data, init_data_size);
@@ -299,6 +337,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
node.builtin_data = builtin_data_deleter.release();
// TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size`
// properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.
+
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
// When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
// `Operator` table is passed in.
@@ -309,6 +348,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
node.custom_initial_data_size = 0;
}
+ node.delegate = nullptr;
node_and_reg.second = *registration;
execution_plan_.push_back(new_node_index);
return kTfLiteOk;
@@ -322,7 +362,7 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
invokable_ = false;
- TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
+ TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}
@@ -424,11 +464,29 @@ TfLiteStatus Interpreter::Invoke() {
TfLiteNode& node = nodes_and_registration_[node_index].first;
const TfLiteRegistration& registration =
nodes_and_registration_[node_index].second;
+
+ // TODO(ycling): This is an extra loop through inputs to check if the data
+ // need to be copied from Delegate buffer to raw memory, which is often not
+ // needed. We may want to cache this in prepare to know if this needs to be
+ // done for a node or not.
+ for (int i = 0; i < node.inputs->size; ++i) {
+ int tensor_index = node.inputs->data[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+ if (tensor->delegate && tensor->delegate != node.delegate &&
+ tensor->data_is_stale) {
+ EnsureTensorDataIsReadable(tensor_index);
+ }
+ }
+
EnsureTensorsVectorCapacity();
if (OpInvoke(registration, &node) == kTfLiteError) {
status = kTfLiteError;
}
}
+
return status;
}
@@ -464,6 +522,7 @@ TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
tensors_.resize(tensors_.size() + tensors_to_add);
for (int i = base_index; i < tensors_.size(); i++) {
memset(&tensors_[i], 0, sizeof(tensors_[i]));
+ tensors_[i].delegate_buffer_handle = kTfLiteNullBufferHandle;
}
context_.tensors = tensors_.data();
context_.tensors_size = tensors_.size();
@@ -511,7 +570,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
invokable_ = false;
- TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
return kTfLiteOk;
@@ -536,7 +595,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
&required_bytes));
}
- TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
@@ -613,7 +672,7 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
ReplaceSubgraphsWithDelegateKernels;
context_.GetExecutionPlan = GetExecutionPlan;
- TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
+ TfLiteStatus status = delegate->Prepare(&context_, delegate);
// Remove additional context info.
context_.GetNodeAndRegistration = nullptr;
context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
@@ -621,4 +680,35 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
return status;
}
+TfLiteStatus Interpreter::SetDelegateBufferHandle(
+ int tensor_index, TfLiteDelegateBufferHandle delegate_buffer_handle,
+ TfLiteDelegate* delegate) {
+ TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+
+ TF_LITE_ENSURE(&context_,
+ tensor->delegate == nullptr || tensor->delegate == delegate);
+ tensor->delegate = delegate;
+ if (tensor->delegate_buffer_handle != kTfLiteNullBufferHandle) {
+ TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr);
+ tensor->delegate->FreeBufferHandle(tensor->delegate,
+ &tensor->delegate_buffer_handle);
+ }
+ tensor->delegate_buffer_handle = delegate_buffer_handle;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::GetDelegateBufferHandle(
+ int tensor_index, TfLiteDelegateBufferHandle* delegate_buffer_handle,
+ TfLiteDelegate** delegate) {
+ TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+
+ *delegate = tensor->delegate;
+ *delegate_buffer_handle = tensor->delegate_buffer_handle;
+
+ return kTfLiteOk;
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 04c19644a0..f5fcae90cc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -265,6 +265,46 @@ class Interpreter {
void set_model(const Model* model) { model_ = const_cast<Model*>(model); }
Model* model() const { return model_; }
+ // Ensure the data in `tensor.data` is readable. In case delegate is used,
+ // it might require to copy the data from delegate buffer to raw memory.
+ TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
+ TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+ if (tensor->data_is_stale) {
+ TF_LITE_ENSURE(&context_, tensor->delegate != nullptr);
+ TF_LITE_ENSURE(&context_,
+ tensor->delegate_buffer_handle != kTfLiteNullBufferHandle);
+      // `CopyFromBufferHandle` may be left null by delegates that don't use
+      // their own buffers, but it must be set once data can become stale.
+ TF_LITE_ENSURE(&context_,
+ tensor->delegate->CopyFromBufferHandle != nullptr);
+ tensor->delegate->CopyFromBufferHandle(tensor->delegate,
+ tensor->delegate_buffer_handle,
+ tensor->data.raw, tensor->bytes);
+ tensor->data_is_stale = false;
+ }
+ return kTfLiteOk;
+ }
+
+ // Set the delegate buffer handle to a tensor. It can be called in the
+ // following cases:
+ // 1. Set the buffer handle to a tensor that's not being written by a
+ // delegate. For example, feeding an OpenGL texture as the input of the
+ // inference graph.
+ // 2. Set the buffer handle to a tensor that uses the same delegate.
+ // For example, set an OpenGL texture as the output of inference, while
+ // the node which produces output is an OpenGL delegate node.
+ // WARNING: This is an experimental API and subject to change.
+ TfLiteStatus SetDelegateBufferHandle(
+ int tensor_index, TfLiteDelegateBufferHandle delegate_buffer_handle,
+ TfLiteDelegate* delegate);
+
+ // Get the delegate buffer handle, and the delegate which can process the
+ // buffer handle.
+ // WARNING: This is an experimental API and subject to change.
+ TfLiteStatus GetDelegateBufferHandle(
+ int tensor_index, TfLiteDelegateBufferHandle* delegate_buffer_handle,
+ TfLiteDelegate** delegate);
+
// The default capacity of `tensors_` vector.
static constexpr int kTensorsReservedCapacity = 128;
// The capacity headroom of `tensors_` vector before calling ops'
@@ -355,14 +395,15 @@ class Interpreter {
// Entry point for C API ReplaceSubgraphsWithDelegateKernels
static TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
TfLiteContext* context, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace);
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
// Update the execution graph to replace some of the nodes with stub
// nodes. Specifically any node index that has `nodes[index]==1` will be
// slated for replacement with a delegate kernel specified by registration.
// WARNING: This is an experimental interface that is subject to change.
TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
- TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace);
+ TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
+ TfLiteDelegate* delegate);
// WARNING: This is an experimental interface that is subject to change.
// Gets the internal pointer to a TensorFlow lite node by node_index.
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 2e6727b323..11578fcb69 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -763,24 +763,38 @@ TfLiteRegistration AddOpRegistration() {
}
class TestDelegate : public ::testing::Test {
- public:
- TestDelegate() {
- interpreter_.AddTensors(5);
- interpreter_.SetInputs({0, 1});
- interpreter_.SetOutputs({3, 4});
+ protected:
+ void SetUp() override {
+ interpreter_ = absl::make_unique<Interpreter>();
+ interpreter_->AddTensors(5);
+ interpreter_->SetInputs({0, 1});
+ interpreter_->SetOutputs({3, 4});
TfLiteQuantizationParams quant;
- interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
- quant);
- interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
- quant);
+ interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+ quant);
TfLiteRegistration reg = AddOpRegistration();
- interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
- interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
- interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
+ interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+ }
+
+ void TearDown() override {
+    // The interpreter relies on delegate_ to free resources properly. Thus
+    // the delegate's life cycle must be longer than the interpreter's.
+ interpreter_.reset();
+ delegate_.reset();
+ }
+
+ TfLiteDelegateBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle;
+
+ TfLiteDelegateBufferHandle AllocateBufferHandle() {
+ return ++last_allocated_handle_;
}
protected:
@@ -791,8 +805,8 @@ class TestDelegate : public ::testing::Test {
// value-copyable and compatible with TfLite.
explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
delegate_.Prepare = [](TfLiteContext* context,
- void* data) -> TfLiteStatus {
- auto* simple = reinterpret_cast<SimpleDelegate*>(data);
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ auto* simple = reinterpret_cast<SimpleDelegate*>(delegate->data_);
TfLiteIntArray* nodes_to_separate =
TfLiteIntArrayCreate(simple->nodes_.size());
// Mark nodes that we want in TfLiteIntArray* structure.
@@ -823,10 +837,28 @@ class TestDelegate : public ::testing::Test {
}
context->ReplaceSubgraphsWithDelegateKernels(
- context, FakeFusedRegistration(), nodes_to_separate);
+ context, FakeFusedRegistration(), nodes_to_separate, delegate);
TfLiteIntArrayFree(nodes_to_separate);
return kTfLiteOk;
};
+ delegate_.CopyToBufferHandle =
+ [](TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle delegate_buffer_handle, void* data,
+ int size) -> TfLiteStatus {
+ // TODO(ycling): Implement tests to test buffer copying logic.
+ return kTfLiteOk;
+ };
+ delegate_.CopyFromBufferHandle =
+ [](TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle delegate_buffer_handle, void* data,
+ int size) -> TfLiteStatus {
+ // TODO(ycling): Implement tests to test buffer copying logic.
+ return kTfLiteOk;
+ };
+ delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate,
+ TfLiteDelegateBufferHandle* handle) {
+ *handle = kTfLiteNullBufferHandle;
+ };
// Store type-punned data SimpleDelegate structure.
delegate_.data_ = reinterpret_cast<void*>(this);
}
@@ -843,36 +875,102 @@ class TestDelegate : public ::testing::Test {
std::vector<int> nodes_;
TfLiteDelegate delegate_;
};
- Interpreter interpreter_;
+ std::unique_ptr<Interpreter> interpreter_;
+ std::unique_ptr<SimpleDelegate> delegate_;
};
TEST_F(TestDelegate, BasicDelegate) {
- interpreter_.Invoke();
- SimpleDelegate simple({0, 1, 2});
- interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
- ASSERT_EQ(interpreter_.execution_plan().size(), 1);
- int node = interpreter_.execution_plan()[0];
- const auto* node_and_reg = interpreter_.node_and_registration(node);
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ int node = interpreter_->execution_plan()[0];
+ const auto* node_and_reg = interpreter_->node_and_registration(node);
ASSERT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
}
TEST_F(TestDelegate, ComplexDeligate) {
- interpreter_.Invoke();
- SimpleDelegate simple({1, 2});
- interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
- ASSERT_EQ(interpreter_.execution_plan().size(), 2);
+ ASSERT_EQ(interpreter_->execution_plan().size(), 2);
// 0th should be a non-delegated original op
- ASSERT_EQ(interpreter_.execution_plan()[0], 0);
+ ASSERT_EQ(interpreter_->execution_plan()[0], 0);
// 1st should be a new macro op (3) which didn't exist before
- ASSERT_EQ(interpreter_.execution_plan()[1], 3);
- const auto* node_and_reg = interpreter_.node_and_registration(3);
+ ASSERT_EQ(interpreter_->execution_plan()[1], 3);
+ const auto* node_and_reg = interpreter_->node_and_registration(3);
ASSERT_EQ(node_and_reg->second.custom_name,
SimpleDelegate::FakeFusedRegistration().custom_name);
}
+TEST_F(TestDelegate, SetBufferHandleToInput) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ constexpr int kOutputTensorIndex = 0;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ ASSERT_EQ(tensor->delegate, nullptr);
+ ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteDelegateBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status = interpreter_->SetDelegateBufferHandle(
+ kOutputTensorIndex, handle, delegate);
+ ASSERT_EQ(status, kTfLiteOk);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->delegate_buffer_handle, handle);
+}
+
+TEST_F(TestDelegate, SetBufferHandleToOutput) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ constexpr int kOutputTensorIndex = 3;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ // Before setting the buffer handle, the tensor's `delegate` is already set
+ // because it will be written by the delegate.
+ ASSERT_EQ(tensor->delegate, delegate);
+ ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteDelegateBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status = interpreter_->SetDelegateBufferHandle(
+ kOutputTensorIndex, handle, delegate);
+ ASSERT_EQ(status, kTfLiteOk);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->delegate_buffer_handle, handle);
+}
+
+TEST_F(TestDelegate, SetInvalidHandleToTensor) {
+ interpreter_->Invoke();
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
+ interpreter_->ModifyGraphWithDelegate(delegate);
+
+ SimpleDelegate another_simple_delegate({0, 1, 2});
+
+ constexpr int kOutputTensorIndex = 3;
+ TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex);
+ // Before setting the buffer handle, the tensor's `delegate` is already set
+ // because it will be written by the delegate.
+ ASSERT_EQ(tensor->delegate, delegate);
+ ASSERT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle);
+
+ TfLiteDelegateBufferHandle handle = AllocateBufferHandle();
+ TfLiteStatus status = interpreter_->SetDelegateBufferHandle(
+ kOutputTensorIndex, handle,
+ another_simple_delegate.get_tf_lite_delegate());
+ // Setting a buffer handle to a tensor with another delegate will fail.
+ ASSERT_EQ(status, kTfLiteError);
+ EXPECT_EQ(tensor->delegate, delegate);
+ EXPECT_EQ(tensor->delegate_buffer_handle, kTfLiteNullBufferHandle);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
new file mode 100644
index 0000000000..b2c7e6c7a6
--- /dev/null
+++ b/tensorflow/contrib/lite/util.cc
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
+ TfLiteIntArray* output = TfLiteIntArrayCreate(input.size());
+ for (size_t i = 0; i < input.size(); i++) {
+ output->data[i] = input[i];
+ }
+ return output;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
new file mode 100644
index 0000000000..50e4fb839e
--- /dev/null
+++ b/tensorflow/contrib/lite/util.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file provides general C++ utility functions in TFLite.
+// For example: Converting between `TfLiteIntArray`, `std::vector` and
+// Flatbuffer vectors. These functions can't live in `context.h` since it's pure
+// C.
+
+#ifndef TENSORFLOW_CONTRIB_LITE_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_UTIL_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// Converts a `std::vector` to a `TfLiteIntArray`.
+TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
new file mode 100644
index 0000000000..04579c53aa
--- /dev/null
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+namespace {
+
+TEST(ConvertVectorToTfLiteIntArray, TestWithVector) {
+ std::vector<int> input = {1, 2};
+ TfLiteIntArray* output = ConvertVectorToTfLiteIntArray(input);
+ ASSERT_NE(output, nullptr);
+ EXPECT_EQ(output->size, 2);
+ EXPECT_EQ(output->data[0], 1);
+ EXPECT_EQ(output->data[1], 2);
+ TfLiteIntArrayFree(output);
+}
+
+TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) {
+ std::vector<int> input;
+ TfLiteIntArray* output = ConvertVectorToTfLiteIntArray(input);
+ ASSERT_NE(output, nullptr);
+ EXPECT_EQ(output->size, 0);
+ TfLiteIntArrayFree(output);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}