aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/delegate.cc
diff options
context:
space:
mode:
authorGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-12 16:21:41 -0700
committerGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-12 16:21:41 -0700
commit9523a98466d16cf01fc76a67b489f1124cf626ac (patch)
treebd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/contrib/lite/delegates/eager/delegate.cc
parent93e950c308071071f35d6dcb35b9f91b8a34876c (diff)
parent1a22b0b982fa1a953651b98af8f3cd30542048fd (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/delegate.cc')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.cc38
1 files changed, 23 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
index 673859da48..7d22b45419 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
#include "tensorflow/contrib/lite/util.h"
#include "tensorflow/core/lib/core/status.h"
@@ -27,7 +28,7 @@ namespace eager {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
- // Get the nodes in the current execution plan.
+ // Get the nodes in the current execution plan. Interpreter owns this array.
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
@@ -39,8 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
- if (registration->custom_name &&
- strncmp(registration->custom_name, "Eager", 5) == 0) {
+ if (IsEagerOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
@@ -63,6 +63,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
BufferMap* buffer_map =
reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap();
+ // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf.
if (!buffer_map->HasTensor(buffer_handle)) {
fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle);
return kTfLiteError;
@@ -83,20 +84,27 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
} // namespace delegate
} // namespace eager
-EagerDelegate::EagerDelegate() {
- if (!eager::DelegateData::Create(&delegate_data_).ok()) {
- fprintf(stderr, "Unable to initialize TensorFlow context.\n");
- return;
+EagerDelegate::EagerDelegate() {}
+
+EagerDelegate::~EagerDelegate() {}
+
+TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) {
+ if (!delegate_) {
+ if (!eager::DelegateData::Create(&delegate_data_).ok()) {
+ fprintf(stderr, "Unable to initialize TensorFlow context.\n");
+ return kTfLiteError;
+ }
+
+ delegate_.reset(new TfLiteDelegate{
+ /*data_=*/delegate_data_.get(),
+ /*nullptr,*/ &eager::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*CopyToBufferHandle=*/nullptr,
+ /*FreeBufferHandle=*/nullptr});
}
- delegate_.reset(new TfLiteDelegate{
- /*data_=*/delegate_data_.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
- /*CopyToBufferHandle=*/nullptr,
- /*FreeBufferHandle=*/nullptr});
+ return interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
}
-EagerDelegate::~EagerDelegate() {}
-
} // namespace tflite