aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-08-15 17:10:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 17:17:23 -0700
commitec9d2891c3b1a5b9fd851f67a3cccf8ebaff38ad (patch)
tree010be4c118c80f17478e4536b36e91d57d07c742 /tensorflow/contrib/lite/delegates
parentc3ef5d7034a50ca1b500c6fabea9250d38628884 (diff)
Internal change
PiperOrigin-RevId: 208910402
Diffstat (limited to 'tensorflow/contrib/lite/delegates')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD4
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.cc37
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h22
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc12
4 files changed, 37 insertions, 38 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 5a7eb370f6..87486e8814 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,7 +16,6 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
@@ -55,7 +54,6 @@ cc_library(
":delegate_data",
":kernel",
":util",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
@@ -119,7 +117,6 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -169,7 +166,6 @@ cc_library(
deps = [
":constants",
"//tensorflow/c:c_api_internal",
- "//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
index 8ab768575e..45fc158157 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -83,27 +83,26 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace eager
-EagerDelegate::EagerDelegate() {}
-
-EagerDelegate::~EagerDelegate() {}
-
-TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) {
- if (!delegate_) {
- if (!eager::DelegateData::Create(&delegate_data_).ok()) {
- fprintf(stderr, "Unable to initialize TensorFlow context.\n");
- return kTfLiteError;
- }
-
- delegate_.reset(new TfLiteDelegate{
- /*data_=*/delegate_data_.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
- /*CopyToBufferHandle=*/nullptr,
- /*FreeBufferHandle=*/nullptr});
+std::unique_ptr<EagerDelegate> EagerDelegate::Create() {
+ std::unique_ptr<eager::DelegateData> delegate_data;
+ if (!eager::DelegateData::Create(&delegate_data).ok()) {
+ fprintf(stderr, "Unable to initialize TensorFlow context.\n");
+ return nullptr;
}
- return interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
+ return std::unique_ptr<EagerDelegate>(
+ new EagerDelegate(std::move(delegate_data)));
}
+EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data)
+ : TfLiteDelegate{
+ /*data_=*/delegate_data.get(),
+ /*nullptr,*/ &eager::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*CopyToBufferHandle=*/nullptr,
+ /*FreeBufferHandle=*/nullptr},
+ delegate_data_(std::move(delegate_data)) {}
+
+EagerDelegate::~EagerDelegate() {}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index a07002f487..6d15ba47dc 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
@@ -30,24 +29,29 @@ namespace tflite {
// interpreters, but it is *not* thread-safe.
//
// Usage:
-// EagerDelegate delegate;
+// auto delegate = EagerDelegate::Create();
// ... build interpreter ...
//
-// delegate.Apply(interpreter);
+// if (delegate) {
+// interpreter->ModifyGraphWithDelegate(
+// delegate.get(), /*allow_dynamic_tensors=*/true);
+// }
// ... run inference ...
// ... destroy interpreter ...
// ... destroy delegate ...
-class EagerDelegate {
+class EagerDelegate : public TfLiteDelegate {
public:
- EagerDelegate();
- ~EagerDelegate();
+ // Creates a delegate that supports TF ops.
+ //
+ // If the underyling TF Eager context creation fails, returns null.
+ static std::unique_ptr<EagerDelegate> Create();
- // Modifies the graph loaded in the interpreter.
- TfLiteStatus Apply(Interpreter* interpreter);
+ ~EagerDelegate();
private:
+ explicit EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data);
+
std::unique_ptr<eager::DelegateData> delegate_data_;
- std::unique_ptr<TfLiteDelegate> delegate_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 511a239363..eb47f46c0b 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -28,21 +28,21 @@ using ::testing::ElementsAre;
class DelegateTest : public testing::EagerModelTest {
public:
DelegateTest() {
- // The delegate needs to be constructed before the interpreter because the
- // interpreter references data contained in the delegate.
- delegate_.reset(new EagerDelegate());
+ delegate_ = EagerDelegate::Create();
interpreter_.reset(new Interpreter(&error_reporter_));
}
~DelegateTest() override {
// The delegate needs to be destructed after the interpreter because the
// interpreter references data contained in the delegate.
- delete interpreter_.release();
- delete delegate_.release();
+ interpreter_.reset();
+ delegate_.reset();
}
void ConfigureDelegate() {
- CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk);
+ ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(
+ delegate_.get(), /*allow_dynamic_tensors=*/true),
+ kTfLiteOk);
}
private: