diff options
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/delegate.cc')
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/delegate.cc | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc new file mode 100644 index 0000000000..45fc158157 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include <vector> + +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +namespace delegate { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Get the nodes in the current execution plan. Interpreter owns this array. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + // Add all custom ops starting with "Eager" to list of supported nodes. + std::vector<int> supported_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (IsEagerOp(registration->custom_name)) { + supported_nodes.push_back(node_index); + } + } + + // Request TFLite to partition the graph and make kernels for each independent + // subgraph. + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(), + size_and_nodes, delegate); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) { + BufferMap* buffer_map = + reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context); + + if (!buffer_map->HasTensor(buffer_handle)) { + context->ReportError(context, "Invalid tensor index %d.", buffer_handle); + return kTfLiteError; + } + + tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle); + tensorflow::StringPiece t_data = t.tensor_data(); + + if (size != t_data.size()) { + context->ReportError( + context, "Not enough space to store TensorFlow's aligned buffer."); + return kTfLiteError; + } + + memcpy(data, t_data.data(), t_data.size()); + return kTfLiteOk; +} + +} // namespace delegate +} // namespace eager + +std::unique_ptr<EagerDelegate> EagerDelegate::Create() { + std::unique_ptr<eager::DelegateData> delegate_data; + if (!eager::DelegateData::Create(&delegate_data).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return nullptr; + } + + return std::unique_ptr<EagerDelegate>( + new EagerDelegate(std::move(delegate_data))); +} + +EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data) + : TfLiteDelegate{ + /*data_=*/delegate_data.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}, + delegate_data_(std::move(delegate_data)) {} + +EagerDelegate::~EagerDelegate() {} + +} // namespace tflite |