author    Yu-Cheng Ling <ycling@google.com>  2018-03-08 11:56:29 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-08 12:00:28 -0800
commit    6e3a43f4b7a1288c878b5daff274f1229256fbe8 (patch)
tree      65dbe6000d5bf896015d64519a2d5783c1857e83 /tensorflow/contrib/lite/interpreter.cc
parent    214ad0978641a946c25b334c4a33ecd1793b4d70 (diff)
TFLite: Delegate Buffer Handle interface (take 2)
PiperOrigin-RevId: 188366045
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r-- tensorflow/contrib/lite/interpreter.cc | 154
1 file changed, 122 insertions(+), 32 deletions(-)
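
Note: this commit introduces a delegate buffer handle interface. Each TfLiteTensor can now carry a buffer_handle owned by a TfLiteDelegate, set with Interpreter::SetBufferHandle, queried with Interpreter::GetBufferHandle, and released through the delegate's FreeBufferHandle callback. Below is a minimal, non-authoritative sketch of how a delegate might hook into this. Only the signatures visible in the diff (Prepare, FreeBufferHandle, data_, SetBufferHandle, GetBufferHandle, ModifyGraphWithDelegate, kTfLiteNullBufferHandle) are taken from the source; the helper names, the integer handle value, and the callback return types are assumptions.

    #include "tensorflow/contrib/lite/context.h"
    #include "tensorflow/contrib/lite/interpreter.h"

    namespace {

    // Hypothetical: release whatever device resource backs `handle`, then mark
    // it null. Signature inferred from the call in this diff; the declared
    // return type in context.h may differ.
    void MyFreeBufferHandle(TfLiteDelegate* delegate, TfLiteBufferHandle* handle) {
      // ... release the backing device buffer here ...
      *handle = kTfLiteNullBufferHandle;
    }

    // Hypothetical: claim nodes for the delegate. With this change, Prepare
    // receives the TfLiteDelegate* itself rather than delegate->data_.
    TfLiteStatus MyDelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
      // context->ReplaceSubgraphsWithDelegateKernels(context, registration,
      //                                              nodes_to_replace, delegate);
      return kTfLiteOk;
    }

    }  // namespace

    // Hypothetical driver showing the new buffer-handle calls on the interpreter.
    void AttachDelegate(tflite::Interpreter* interpreter, int output_tensor_index) {
      static TfLiteDelegate my_delegate = {};
      my_delegate.data_ = nullptr;
      my_delegate.Prepare = MyDelegatePrepare;
      my_delegate.FreeBufferHandle = MyFreeBufferHandle;

      interpreter->ModifyGraphWithDelegate(&my_delegate);

      // Bind a delegate-owned buffer to an output tensor. If the tensor already
      // had a handle, SetBufferHandle frees it via FreeBufferHandle first.
      TfLiteBufferHandle handle = 1;  // hypothetical device buffer id
      interpreter->SetBufferHandle(output_tensor_index, handle, &my_delegate);

      // Read the association back.
      TfLiteBufferHandle stored = kTfLiteNullBufferHandle;
      TfLiteDelegate* owner = nullptr;
      interpreter->GetBufferHandle(output_tensor_index, &stored, &owner);
    }
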
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 0f5e17f0de..8fd1085544 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
@@ -96,19 +97,57 @@ Interpreter::~Interpreter() {
}
for (int i = 0; i < context_.tensors_size; i++) {
- TfLiteTensorFree(&context_.tensors[i]);
+ TfLiteTensor* tensor = &context_.tensors[i];
+ if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
+ tensor->delegate->FreeBufferHandle(tensor->delegate,
+ &tensor->buffer_handle);
+ }
+ TfLiteTensorFree(tensor);
}
}
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
TfLiteContext* context, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace) {
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) {
return static_cast<Interpreter*>(context->impl_)
- ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
+ ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace,
+ delegate);
+}
+
+namespace {
+
+// This function allocates a contiguous memory block that contains a
+// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be
+// deallocated later by the C `free` function.
+TfLiteDelegateParams* CreateDelegateParams(
+ TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) {
+ int nodes_to_replace_size_in_bytes =
+ TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size());
+ void* allocation =
+ malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes);
+ TfLiteDelegateParams* params =
+ reinterpret_cast<TfLiteDelegateParams*>(allocation);
+ TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>(
+ static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams));
+
+ nodes_to_replace_arr->size = nodes_to_replace.size();
+ for (int i = 0; i < nodes_to_replace.size(); ++i) {
+ nodes_to_replace_arr->data[i] = nodes_to_replace[i];
+ }
+
+ params->delegate = delegate;
+ params->nodes_to_replace = nodes_to_replace_arr;
+ return params;
}
+} // Anonymous namespace
+
TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
- TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
+ TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace,
+ TfLiteDelegate* delegate) {
// Annotate the registration as DELEGATE op.
registration.builtin_code = BuiltinOperator_DELEGATE;
@@ -120,30 +159,38 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
execution_plan_.clear();
for (auto& subgraph : subgraphs) {
- // Turn subgraph.nodes into a TfLiteIntArray compatible data structure.
- // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
- // in the first place
- subgraph.nodes.insert(subgraph.nodes.begin(),
- static_cast<int>(subgraph.nodes.size()));
    // Subgraphs claimed by the delegate should have a "macro" op created; the
// other subgraphs (kTfNonPartition) just have their nodes added back to
// the execution plan.
switch (subgraph.type) {
case Subgraph::kTfNonPartition:
- for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
+ for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end();
++it) {
execution_plan_.push_back(*it);
}
break;
case Subgraph::kTfPartition: {
- void* builtin_data = nullptr;
int node_index;
- // Create a node that represents computation of this subgraph.
- AddNodeWithParameters(
- subgraph.input_tensors, subgraph.output_tensors,
- reinterpret_cast<const char*>(subgraph.nodes.data()),
- subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
- &registration, &node_index);
+
+ TfLiteDelegateParams* params =
+ CreateDelegateParams(delegate, subgraph.nodes);
+ AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
+ nullptr, 0, params, &registration, &node_index);
+
+        // Initialize the output tensors' delegate-related fields.
+ for (int tensor_index : subgraph.output_tensors) {
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+ TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr);
+ TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle,
+ kTfLiteNullBufferHandle);
+          // buffer_handle will be filled in by the delegate's `Prepare`
+          // function.
+ tensor->delegate = delegate;
+ }
+
+ // Associate the node with the delegate.
+ TfLiteNode* node = &nodes_and_registration_[node_index].first;
+ node->delegate = delegate;
} break;
case Subgraph::kTfUnexplored:
return kTfLiteError;
@@ -233,14 +280,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
return kTfLiteOk;
}
-namespace {
-TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
- TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
- for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
- return lite;
-}
-} // namespace
-
TfLiteStatus Interpreter::AllocateTensors() {
next_execution_plan_index_to_prepare_ = 0;
if (memory_planner_) {
@@ -275,7 +314,6 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
-
auto& node_and_reg = nodes_and_registration_.back();
TfLiteNode& node = node_and_reg.first;
if (node.inputs) TfLiteIntArrayFree(node.inputs);
@@ -285,8 +323,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
// NOTE, here we are not using move semantics yet, since our internal
// representation isn't std::vector, but in the future we would like to avoid
// copies, so we want the interface to take r-value references now.
- node.inputs = convertVectorToTfLiteIntArray(inputs);
- node.outputs = convertVectorToTfLiteIntArray(outputs);
+ node.inputs = ConvertVectorToTfLiteIntArray(inputs);
+ node.outputs = ConvertVectorToTfLiteIntArray(outputs);
node.temporaries = TfLiteIntArrayCreate(0);
if (init_data) {
node.user_data = OpInit(*registration, init_data, init_data_size);
@@ -299,6 +337,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
node.builtin_data = builtin_data_deleter.release();
// TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size`
// properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.
+
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
// When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
// `Operator` table is passed in.
@@ -309,6 +348,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
node.custom_initial_data_size = 0;
}
+ node.delegate = nullptr;
node_and_reg.second = *registration;
execution_plan_.push_back(new_node_index);
return kTfLiteOk;
@@ -322,7 +362,7 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
invokable_ = false;
- TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
+ TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}
@@ -424,11 +464,29 @@ TfLiteStatus Interpreter::Invoke() {
TfLiteNode& node = nodes_and_registration_[node_index].first;
const TfLiteRegistration& registration =
nodes_and_registration_[node_index].second;
+
+    // TODO(ycling): This is an extra loop through the inputs to check whether
+    // the data needs to be copied from a delegate buffer to raw memory, which
+    // is often not needed. We may want to cache this during `Prepare` to know
+    // whether it needs to be done for a node.
+ for (int i = 0; i < node.inputs->size; ++i) {
+ int tensor_index = node.inputs->data[i];
+ if (tensor_index == kOptionalTensor) {
+ continue;
+ }
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+ if (tensor->delegate && tensor->delegate != node.delegate &&
+ tensor->data_is_stale) {
+ EnsureTensorDataIsReadable(tensor_index);
+ }
+ }
+
EnsureTensorsVectorCapacity();
if (OpInvoke(registration, &node) == kTfLiteError) {
status = kTfLiteError;
}
}
+
return status;
}
@@ -464,6 +522,7 @@ TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
tensors_.resize(tensors_.size() + tensors_to_add);
for (int i = base_index; i < tensors_.size(); i++) {
memset(&tensors_[i], 0, sizeof(tensors_[i]));
+ tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
}
context_.tensors = tensors_.data();
context_.tensors_size = tensors_.size();
@@ -511,7 +570,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
invokable_ = false;
- TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
return kTfLiteOk;
@@ -536,7 +595,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
&required_bytes));
}
- TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
+ TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
@@ -613,7 +672,7 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
ReplaceSubgraphsWithDelegateKernels;
context_.GetExecutionPlan = GetExecutionPlan;
- TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
+ TfLiteStatus status = delegate->Prepare(&context_, delegate);
// Remove additional context info.
context_.GetNodeAndRegistration = nullptr;
context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
@@ -621,4 +680,35 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
return status;
}
+TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
+ TfLiteBufferHandle buffer_handle,
+ TfLiteDelegate* delegate) {
+ TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+
+ TF_LITE_ENSURE(&context_,
+ tensor->delegate == nullptr || tensor->delegate == delegate);
+ tensor->delegate = delegate;
+ if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
+ TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr);
+ tensor->delegate->FreeBufferHandle(tensor->delegate,
+ &tensor->buffer_handle);
+ }
+ tensor->buffer_handle = buffer_handle;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
+ TfLiteBufferHandle* buffer_handle,
+ TfLiteDelegate** delegate) {
+ TF_LITE_ENSURE(&context_, tensor_index < tensors_size());
+ TfLiteTensor* tensor = &tensors_[tensor_index];
+
+ *delegate = tensor->delegate;
+ *buffer_handle = tensor->buffer_handle;
+
+ return kTfLiteOk;
+}
+
} // namespace tflite
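
Footnote on CreateDelegateParams above: it packs a TfLiteDelegateParams and the TfLiteIntArray it points to into a single malloc'd block, and that pointer is handed to AddNodeWithParameters as builtin_data, so it ends up stored on the node and released through the node's builtin_data cleanup path. The single allocation keeps the int array alive for the node's lifetime without a separate free. A rough sketch of how a delegate "macro" kernel could read those params back at invoke time follows; the layout and the node->builtin_data storage follow from this diff, while the kernel itself is hypothetical.

    // Hypothetical delegate kernel reading back the packed TfLiteDelegateParams.
    TfLiteStatus MyDelegateKernelInvoke(TfLiteContext* context, TfLiteNode* node) {
      auto* params = reinterpret_cast<TfLiteDelegateParams*>(node->builtin_data);
      for (int i = 0; i < params->nodes_to_replace->size; ++i) {
        const int replaced_node_index = params->nodes_to_replace->data[i];
        // ... execute the delegate's compiled equivalent of this node ...
        (void)replaced_node_index;
      }
      return kTfLiteOk;
    }
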