diff options
author | 2018-08-12 16:21:41 -0700 | |
---|---|---|
committer | 2018-08-12 16:21:41 -0700 | |
commit | 9523a98466d16cf01fc76a67b489f1124cf626ac (patch) | |
tree | bd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/contrib/lite/delegates/eager/delegate.cc | |
parent | 93e950c308071071f35d6dcb35b9f91b8a34876c (diff) | |
parent | 1a22b0b982fa1a953651b98af8f3cd30542048fd (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/delegate.cc')
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/delegate.cc | 38 |
1 files changed, 23 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc index 673859da48..7d22b45419 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" #include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" #include "tensorflow/contrib/lite/util.h" #include "tensorflow/core/lib/core/status.h" @@ -27,7 +28,7 @@ namespace eager { namespace delegate { TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { - // Get the nodes in the current execution plan. + // Get the nodes in the current execution plan. Interpreter owns this array. TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); @@ -39,8 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); - if (registration->custom_name && - strncmp(registration->custom_name, "Eager", 5) == 0) { + if (IsEagerOp(registration->custom_name)) { supported_nodes.push_back(node_index); } } @@ -63,6 +63,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, BufferMap* buffer_map = reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(); + // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf. if (!buffer_map->HasTensor(buffer_handle)) { fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle); return kTfLiteError; @@ -83,20 +84,27 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, } // namespace delegate } // namespace eager -EagerDelegate::EagerDelegate() { - if (!eager::DelegateData::Create(&delegate_data_).ok()) { - fprintf(stderr, "Unable to initialize TensorFlow context.\n"); - return; +EagerDelegate::EagerDelegate() {} + +EagerDelegate::~EagerDelegate() {} + +TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { + if (!delegate_) { + if (!eager::DelegateData::Create(&delegate_data_).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return kTfLiteError; + } + + delegate_.reset(new TfLiteDelegate{ + /*data_=*/delegate_data_.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}); } - delegate_.reset(new TfLiteDelegate{ - /*data_=*/delegate_data_.get(), - /*nullptr,*/ &eager::delegate::Prepare, - /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, - /*CopyToBufferHandle=*/nullptr, - /*FreeBufferHandle=*/nullptr}); + return interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); } -EagerDelegate::~EagerDelegate() {} - } // namespace tflite |