diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-03-08 11:56:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-08 12:00:28 -0800 |
commit | 6e3a43f4b7a1288c878b5daff274f1229256fbe8 (patch) | |
tree | 65dbe6000d5bf896015d64519a2d5783c1857e83 /tensorflow/contrib/lite/interpreter.cc | |
parent | 214ad0978641a946c25b334c4a33ecd1793b4d70 (diff) |
TFLite: Delegate Buffer Handle interface (take 2)
PiperOrigin-RevId: 188366045
Diffstat (limited to 'tensorflow/contrib/lite/interpreter.cc')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 154 |
1 file changed, 122 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 0f5e17f0de..8fd1085544 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" namespace tflite { @@ -96,19 +97,57 @@ Interpreter::~Interpreter() { } for (int i = 0; i < context_.tensors_size; i++) { - TfLiteTensorFree(&context_.tensors[i]); + TfLiteTensor* tensor = &context_.tensors[i]; + if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + tensor->delegate->FreeBufferHandle(tensor->delegate, + &tensor->buffer_handle); + } + TfLiteTensorFree(tensor); } } TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( TfLiteContext* context, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace) { + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate) { return static_cast<Interpreter*>(context->impl_) - ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace); + ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace, + delegate); +} + +namespace { + +// This function allocates a continuous memory space that contains a +// TfLiteDelegateParams followed by a TfLiteIntArray. The pointer will be +// deallocated by C `free` function later. 
+TfLiteDelegateParams* CreateDelegateParams( + TfLiteDelegate* delegate, const std::vector<int>& nodes_to_replace) { + int nodes_to_replace_size_in_bytes = + TfLiteIntArrayGetSizeInBytes(nodes_to_replace.size()); + void* allocation = + malloc(sizeof(TfLiteDelegateParams) + nodes_to_replace_size_in_bytes); + TfLiteDelegateParams* params = + reinterpret_cast<TfLiteDelegateParams*>(allocation); + TfLiteIntArray* nodes_to_replace_arr = reinterpret_cast<TfLiteIntArray*>( + static_cast<char*>(allocation) + sizeof(TfLiteDelegateParams)); + + nodes_to_replace_arr->size = nodes_to_replace.size(); + for (int i = 0; i < nodes_to_replace.size(); ++i) { + nodes_to_replace_arr->data[i] = nodes_to_replace[i]; + } + + params->delegate = delegate; + params->nodes_to_replace = nodes_to_replace_arr; + return params; } +} // Anonymous namespace + TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( - TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) { + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, + TfLiteDelegate* delegate) { + // Annotate the registration as DELEGATE op. + registration.builtin_code = BuiltinOperator_DELEGATE; + // Annotate the registration as DELEGATE op. registration.builtin_code = BuiltinOperator_DELEGATE; @@ -120,30 +159,38 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( execution_plan_.clear(); for (auto& subgraph : subgraphs) { - // Turn subgraph.nodes into a TfLiteIntArray compatible data structure. - // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way - // in the first place - subgraph.nodes.insert(subgraph.nodes.begin(), - static_cast<int>(subgraph.nodes.size())); // Subgraphs claimed by the delegate should have a "macro" op created, the // other subgraphs (kTfNonPartition) just have their nodes added back to // the execution plan. 
switch (subgraph.type) { case Subgraph::kTfNonPartition: - for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end(); + for (auto it = subgraph.nodes.begin(); it != subgraph.nodes.end(); ++it) { execution_plan_.push_back(*it); } break; case Subgraph::kTfPartition: { - void* builtin_data = nullptr; int node_index; - // Create a node that represents computation of this subgraph. - AddNodeWithParameters( - subgraph.input_tensors, subgraph.output_tensors, - reinterpret_cast<const char*>(subgraph.nodes.data()), - subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data, - &registration, &node_index); + + TfLiteDelegateParams* params = + CreateDelegateParams(delegate, subgraph.nodes); + AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, + nullptr, 0, params, &registration, &node_index); + + // Initialize the output tensors' delegate-related fields. + for (int tensor_index : subgraph.output_tensors) { + TfLiteTensor* tensor = &tensors_[tensor_index]; + TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr); + TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle, + kTfLiteNullBufferHandle); + // buffer_handle will be filled in delegate's `Prepare` + // function. + tensor->delegate = delegate; + } + + // Associate the node with the delegate. 
+ TfLiteNode* node = &nodes_and_registration_[node_index].first; + node->delegate = delegate; } break; case Subgraph::kTfUnexplored: return kTfLiteError; @@ -233,14 +280,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, return kTfLiteOk; } -namespace { -TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) { - TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); - for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; - return lite; -} -} // namespace - TfLiteStatus Interpreter::AllocateTensors() { next_execution_plan_index_to_prepare_ = 0; if (memory_planner_) { @@ -275,7 +314,6 @@ TfLiteStatus Interpreter::AddNodeWithParameters( int new_node_index = nodes_and_registration_.size(); if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); - auto& node_and_reg = nodes_and_registration_.back(); TfLiteNode& node = node_and_reg.first; if (node.inputs) TfLiteIntArrayFree(node.inputs); @@ -285,8 +323,8 @@ TfLiteStatus Interpreter::AddNodeWithParameters( // NOTE, here we are not using move semantics yet, since our internal // representation isn't std::vector, but in the future we would like to avoid // copies, so we want the interface to take r-value references now. - node.inputs = convertVectorToTfLiteIntArray(inputs); - node.outputs = convertVectorToTfLiteIntArray(outputs); + node.inputs = ConvertVectorToTfLiteIntArray(inputs); + node.outputs = ConvertVectorToTfLiteIntArray(outputs); node.temporaries = TfLiteIntArrayCreate(0); if (init_data) { node.user_data = OpInit(*registration, init_data, init_data_size); @@ -299,6 +337,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( node.builtin_data = builtin_data_deleter.release(); // TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size` // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels. 
+ if (registration->builtin_code == BuiltinOperator_CUSTOM) { // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer // `Operator` table is passed in. @@ -309,6 +348,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( node.custom_initial_data_size = 0; } + node.delegate = nullptr; node_and_reg.second = *registration; execution_plan_.push_back(new_node_index); return kTfLiteOk; @@ -322,7 +362,7 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); invokable_ = false; - TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); } @@ -424,11 +464,29 @@ TfLiteStatus Interpreter::Invoke() { TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + + // TODO(ycling): This is an extra loop through inputs to check if the data + // need to be copied from Delegate buffer to raw memory, which is often not + // needed. We may want to cache this in prepare to know if this needs to be + // done for a node or not. 
+ for (int i = 0; i < node.inputs->size; ++i) { + int tensor_index = node.inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor* tensor = &tensors_[tensor_index]; + if (tensor->delegate && tensor->delegate != node.delegate && + tensor->data_is_stale) { + EnsureTensorDataIsReadable(tensor_index); + } + } + EnsureTensorsVectorCapacity(); if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } } + return status; } @@ -464,6 +522,7 @@ TfLiteStatus Interpreter::AddTensors(int tensors_to_add, tensors_.resize(tensors_.size() + tensors_to_add); for (int i = base_index; i < tensors_.size(); i++) { memset(&tensors_[i], 0, sizeof(tensors_[i])); + tensors_[i].buffer_handle = kTfLiteNullBufferHandle; } context_.tensors = tensors_.data(); context_.tensors_size = tensors_.size(); @@ -511,7 +570,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); } invokable_ = false; - TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims), quantization, const_cast<char*>(buffer), bytes, kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); return kTfLiteOk; @@ -536,7 +595,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), &required_bytes)); } - TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims), quantization, /*buffer=*/nullptr, required_bytes, type == kTfLiteString ? 
kTfLiteDynamic : kTfLiteArenaRw, @@ -613,7 +672,7 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { ReplaceSubgraphsWithDelegateKernels; context_.GetExecutionPlan = GetExecutionPlan; - TfLiteStatus status = delegate->Prepare(&context_, delegate->data_); + TfLiteStatus status = delegate->Prepare(&context_, delegate); // Remove additional context info. context_.GetNodeAndRegistration = nullptr; context_.ReplaceSubgraphsWithDelegateKernels = nullptr; @@ -621,4 +680,35 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { return status; } +TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, + TfLiteBufferHandle buffer_handle, + TfLiteDelegate* delegate) { + TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); + TfLiteTensor* tensor = &tensors_[tensor_index]; + + TF_LITE_ENSURE(&context_, + tensor->delegate == nullptr || tensor->delegate == delegate); + tensor->delegate = delegate; + if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr); + tensor->delegate->FreeBufferHandle(tensor->delegate, + &tensor->buffer_handle); + } + tensor->buffer_handle = buffer_handle; + + return kTfLiteOk; +} + +TfLiteStatus Interpreter::GetBufferHandle(int tensor_index, + TfLiteBufferHandle* buffer_handle, + TfLiteDelegate** delegate) { + TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); + TfLiteTensor* tensor = &tensors_[tensor_index]; + + *delegate = tensor->delegate; + *buffer_handle = tensor->buffer_handle; + + return kTfLiteOk; +} + } // namespace tflite |