diff options
author | 2018-08-08 13:33:49 -0700 | |
---|---|---|
committer | 2018-08-08 13:38:02 -0700 | |
commit | 54d92a58cad8619460889bd1b1ef34df89d1b612 (patch) | |
tree | 5c61666c2331dfc7e5d2001a679bff761c78325e | |
parent | 1ce16d0f9e5282d74c51f8471d1962ada8c566ed (diff) |
Minor fixes to Eager delegate.
PiperOrigin-RevId: 207937525
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/delegate.cc | 38 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/delegate.h | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/util_test.cc | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 7 |
7 files changed, 52 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 332a871446..f21540d524 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -50,6 +50,7 @@ cc_library( ":buffer_map", ":delegate_data", ":kernel", + ":util", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:util", @@ -154,6 +155,7 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + ":constants", "//tensorflow/c:c_api_internal", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc index 673859da48..7d22b45419 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" #include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" #include "tensorflow/contrib/lite/util.h" #include "tensorflow/core/lib/core/status.h" @@ -27,7 +28,7 @@ namespace eager { namespace delegate { TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { - // Get the nodes in the current execution plan. + // Get the nodes in the current execution plan. Interpreter owns this array. TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); @@ -39,8 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); - if (registration->custom_name && - strncmp(registration->custom_name, "Eager", 5) == 0) { + if (IsEagerOp(registration->custom_name)) { supported_nodes.push_back(node_index); } } @@ -63,6 +63,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, BufferMap* buffer_map = reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(); + // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf. if (!buffer_map->HasTensor(buffer_handle)) { fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle); return kTfLiteError; @@ -83,20 +84,27 @@ TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, } // namespace delegate } // namespace eager -EagerDelegate::EagerDelegate() { - if (!eager::DelegateData::Create(&delegate_data_).ok()) { - fprintf(stderr, "Unable to initialize TensorFlow context.\n"); - return; +EagerDelegate::EagerDelegate() {} + +EagerDelegate::~EagerDelegate() {} + +TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { + if (!delegate_) { + if (!eager::DelegateData::Create(&delegate_data_).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return kTfLiteError; + } + + delegate_.reset(new TfLiteDelegate{ + /*data_=*/delegate_data_.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}); } - delegate_.reset(new TfLiteDelegate{ - /*data_=*/delegate_data_.get(), - /*nullptr,*/ &eager::delegate::Prepare, - /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, - /*CopyToBufferHandle=*/nullptr, - /*FreeBufferHandle=*/nullptr}); + return interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); } -EagerDelegate::~EagerDelegate() {} - } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h index 6259b35931..0defca7c32 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -30,7 +30,7 @@ namespace tflite { // interpreter. // // Usage: -// EagerDelegate delegate(); +// EagerDelegate delegate; // ... build interpreter ... // // delegate.Apply(interpreter); @@ -42,10 +42,8 @@ class EagerDelegate { EagerDelegate(); ~EagerDelegate(); - TfLiteStatus Apply(Interpreter* interpreter) { - return interpreter->ModifyGraphWithDelegate(delegate_.get(), - /*allow_dynamic_tensors=*/true); - } + // Modifies the graph loaded in the interpreter. + TfLiteStatus Apply(Interpreter* interpreter); private: std::unique_ptr<eager::DelegateData> delegate_data_; diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index 4426c653e6..c8aa0b7f69 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -13,10 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/eager/constants.h" namespace tflite { namespace eager { +bool IsEagerOp(const char* custom_name) { + return custom_name && strncmp(custom_name, kCustomCodePrefix, + strlen(kCustomCodePrefix)) == 0; +} + TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status) { if (!status.ok()) { diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index a9407be071..b7363361be 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -23,6 +23,10 @@ limitations under the License. namespace tflite { namespace eager { +// Checks whether the prefix of the custom name indicates the operation is an +// Eager operation. +bool IsEagerOp(const char* custom_name); + // Converts a tensorflow:Status into a TfLiteStatus. If the original status // represented an error, reports it using the given 'context'. TfLiteStatus ConvertStatus(TfLiteContext* context, diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc index c4fbf54127..4e92da8d34 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -102,6 +102,16 @@ TEST(UtilTest, TypeConversions) { EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); } +TEST(UtilTest, IsEagerOp) { + EXPECT_TRUE(IsEagerOp("Eager")); + EXPECT_TRUE(IsEagerOp("EagerOp")); + EXPECT_FALSE(IsEagerOp("eager")); + EXPECT_FALSE(IsEagerOp("Eage")); + EXPECT_FALSE(IsEagerOp("OpEager")); + EXPECT_FALSE(IsEagerOp(nullptr)); + EXPECT_FALSE(IsEagerOp("")); +} + } // namespace } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index be149a8cc0..e8301ff507 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -527,12 +527,13 @@ class Interpreter { TfLiteRegistration** registration); // WARNING: This is an experimental interface that is subject to change. - // Gets an TfLiteIntArray* representing the execution plan. The caller owns - // this memory and must free it with TfLiteIntArrayFree(). + // Gets an TfLiteIntArray* representing the execution plan. The interpreter + // owns this memory and it is only guaranteed to exist during the invocation + // of the delegate prepare. TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); // WARNING: This is an experimental interface that is subject to change. - // Entry point for C node plugin API to get the execution plan + // Entry point for C node plugin API to get the execution plan. static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); |