aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-08-08 13:33:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 13:38:02 -0700
commit54d92a58cad8619460889bd1b1ef34df89d1b612 (patch)
tree5c61666c2331dfc7e5d2001a679bff761c78325e
parent1ce16d0f9e5282d74c51f8471d1962ada8c566ed (diff)
Minor fixes to Eager delegate.
PiperOrigin-RevId: 207937525
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.cc38
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h8
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc6
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h4
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc10
-rw-r--r--tensorflow/contrib/lite/interpreter.h7
7 files changed, 52 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 332a871446..f21540d524 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -50,6 +50,7 @@ cc_library(
":buffer_map",
":delegate_data",
":kernel",
+ ":util",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
@@ -154,6 +155,7 @@ cc_library(
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
+ ":constants",
"//tensorflow/c:c_api_internal",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
index 673859da48..7d22b45419 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
#include "tensorflow/contrib/lite/util.h"
#include "tensorflow/core/lib/core/status.h"
@@ -27,7 +28,7 @@ namespace eager {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
- // Get the nodes in the current execution plan.
+ // Get the nodes in the current execution plan. Interpreter owns this array.
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
@@ -39,8 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
- if (registration->custom_name &&
- strncmp(registration->custom_name, "Eager", 5) == 0) {
+ if (IsEagerOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
@@ -63,6 +63,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
BufferMap* buffer_map =
reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap();
+ // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf.
if (!buffer_map->HasTensor(buffer_handle)) {
fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle);
return kTfLiteError;
@@ -83,20 +84,27 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
} // namespace delegate
} // namespace eager
-EagerDelegate::EagerDelegate() {
- if (!eager::DelegateData::Create(&delegate_data_).ok()) {
- fprintf(stderr, "Unable to initialize TensorFlow context.\n");
- return;
+EagerDelegate::EagerDelegate() {}
+
+EagerDelegate::~EagerDelegate() {}
+
+TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) {
+ if (!delegate_) {
+ if (!eager::DelegateData::Create(&delegate_data_).ok()) {
+ fprintf(stderr, "Unable to initialize TensorFlow context.\n");
+ return kTfLiteError;
+ }
+
+ delegate_.reset(new TfLiteDelegate{
+ /*data_=*/delegate_data_.get(),
+ /*nullptr,*/ &eager::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*CopyToBufferHandle=*/nullptr,
+ /*FreeBufferHandle=*/nullptr});
}
- delegate_.reset(new TfLiteDelegate{
- /*data_=*/delegate_data_.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
- /*CopyToBufferHandle=*/nullptr,
- /*FreeBufferHandle=*/nullptr});
+ return interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
}
-EagerDelegate::~EagerDelegate() {}
-
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6259b35931..0defca7c32 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -30,7 +30,7 @@ namespace tflite {
// interpreter.
//
// Usage:
-// EagerDelegate delegate();
+// EagerDelegate delegate;
// ... build interpreter ...
//
// delegate.Apply(interpreter);
@@ -42,10 +42,8 @@ class EagerDelegate {
EagerDelegate();
~EagerDelegate();
- TfLiteStatus Apply(Interpreter* interpreter) {
- return interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
- }
+ // Modifies the graph loaded in the interpreter.
+ TfLiteStatus Apply(Interpreter* interpreter);
private:
std::unique_ptr<eager::DelegateData> delegate_data_;
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 4426c653e6..c8aa0b7f69 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -13,10 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/eager/constants.h"
namespace tflite {
namespace eager {
+bool IsEagerOp(const char* custom_name) {
+ return custom_name && strncmp(custom_name, kCustomCodePrefix,
+ strlen(kCustomCodePrefix)) == 0;
+}
+
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
if (!status.ok()) {
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index a9407be071..b7363361be 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -23,6 +23,10 @@ limitations under the License.
namespace tflite {
namespace eager {
+// Checks whether the prefix of the custom name indicates the operation is an
+// Eager operation.
+bool IsEagerOp(const char* custom_name);
+
// Converts a tensorflow:Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
TfLiteStatus ConvertStatus(TfLiteContext* context,
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index c4fbf54127..4e92da8d34 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -102,6 +102,16 @@ TEST(UtilTest, TypeConversions) {
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
+TEST(UtilTest, IsEagerOp) {
+ EXPECT_TRUE(IsEagerOp("Eager"));
+ EXPECT_TRUE(IsEagerOp("EagerOp"));
+ EXPECT_FALSE(IsEagerOp("eager"));
+ EXPECT_FALSE(IsEagerOp("Eage"));
+ EXPECT_FALSE(IsEagerOp("OpEager"));
+ EXPECT_FALSE(IsEagerOp(nullptr));
+ EXPECT_FALSE(IsEagerOp(""));
+}
+
} // namespace
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index be149a8cc0..e8301ff507 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -527,12 +527,13 @@ class Interpreter {
TfLiteRegistration** registration);
// WARNING: This is an experimental interface that is subject to change.
- // Gets an TfLiteIntArray* representing the execution plan. The caller owns
- // this memory and must free it with TfLiteIntArrayFree().
+ // Gets an TfLiteIntArray* representing the execution plan. The interpreter
+ // owns this memory and it is only guaranteed to exist during the invocation
+ // of the delegate prepare.
TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);
// WARNING: This is an experimental interface that is subject to change.
- // Entry point for C node plugin API to get the execution plan
+ // Entry point for C node plugin API to get the execution plan.
static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);