diff options
author | 2018-08-15 17:10:25 -0700 | |
---|---|---|
committer | 2018-08-15 17:17:23 -0700 | |
commit | ec9d2891c3b1a5b9fd851f67a3cccf8ebaff38ad (patch) | |
tree | 010be4c118c80f17478e4536b36e91d57d07c742 /tensorflow/contrib/lite/delegates | |
parent | c3ef5d7034a50ca1b500c6fabea9250d38628884 (diff) |
Internal change
PiperOrigin-RevId: 208910402
Diffstat (limited to 'tensorflow/contrib/lite/delegates')
4 files changed, 37 insertions, 38 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 5a7eb370f6..87486e8814 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -16,7 +16,6 @@ cc_library( deps = [ ":util", "//tensorflow/c:c_api_internal", - "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ @@ -55,7 +54,6 @@ cc_library( ":delegate_data", ":kernel", ":util", - "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:util", ] + select({ @@ -119,7 +117,6 @@ cc_library( ":delegate_data", ":util", "@flatbuffers", - "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:string", "//tensorflow/contrib/lite/kernels:kernel_util", @@ -169,7 +166,6 @@ cc_library( deps = [ ":constants", "//tensorflow/c:c_api_internal", - "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc index 8ab768575e..45fc158157 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -83,27 +83,26 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, } // namespace delegate } // namespace eager -EagerDelegate::EagerDelegate() {} - -EagerDelegate::~EagerDelegate() {} - -TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { - if (!delegate_) { - if (!eager::DelegateData::Create(&delegate_data_).ok()) { - fprintf(stderr, "Unable to initialize TensorFlow context.\n"); - return kTfLiteError; - } - - delegate_.reset(new TfLiteDelegate{ - /*data_=*/delegate_data_.get(), - /*nullptr,*/ &eager::delegate::Prepare, - /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, - /*CopyToBufferHandle=*/nullptr, - /*FreeBufferHandle=*/nullptr}); +std::unique_ptr<EagerDelegate> EagerDelegate::Create() { + std::unique_ptr<eager::DelegateData> delegate_data; + if (!eager::DelegateData::Create(&delegate_data).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return nullptr; } - return interpreter->ModifyGraphWithDelegate(delegate_.get(), - /*allow_dynamic_tensors=*/true); + return std::unique_ptr<EagerDelegate>( + new EagerDelegate(std::move(delegate_data))); } +EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data) + : TfLiteDelegate{ + /*data_=*/delegate_data.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}, + delegate_data_(std::move(delegate_data)) {} + +EagerDelegate::~EagerDelegate() {} + } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h index a07002f487..6d15ba47dc 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" -#include "tensorflow/contrib/lite/interpreter.h" namespace tflite { @@ -30,24 +29,29 @@ namespace tflite { // interpreters, but it is *not* thread-safe. // // Usage: -// EagerDelegate delegate; +// auto delegate = EagerDelegate::Create(); // ... build interpreter ... // -// delegate.Apply(interpreter); +// if (delegate) { +// interpreter->ModifyGraphWithDelegate( +// delegate.get(), /*allow_dynamic_tensors=*/true); +// } // ... run inference ... // ... destroy interpreter ... // ... destroy delegate ... -class EagerDelegate { +class EagerDelegate : public TfLiteDelegate { public: - EagerDelegate(); - ~EagerDelegate(); + // Creates a delegate that supports TF ops. + // + // If the underyling TF Eager context creation fails, returns null. + static std::unique_ptr<EagerDelegate> Create(); - // Modifies the graph loaded in the interpreter. - TfLiteStatus Apply(Interpreter* interpreter); + ~EagerDelegate(); private: + explicit EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data); + std::unique_ptr<eager::DelegateData> delegate_data_; - std::unique_ptr<TfLiteDelegate> delegate_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc index 511a239363..eb47f46c0b 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -28,21 +28,21 @@ using ::testing::ElementsAre; class DelegateTest : public testing::EagerModelTest { public: DelegateTest() { - // The delegate needs to be constructed before the interpreter because the - // interpreter references data contained in the delegate. - delegate_.reset(new EagerDelegate()); + delegate_ = EagerDelegate::Create(); interpreter_.reset(new Interpreter(&error_reporter_)); } ~DelegateTest() override { // The delegate needs to be destructed after the interpreter because the // interpreter references data contained in the delegate. - delete interpreter_.release(); - delete delegate_.release(); + interpreter_.reset(); + delegate_.reset(); } void ConfigureDelegate() { - CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk); + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate( + delegate_.get(), /*allow_dynamic_tensors=*/true), + kTfLiteOk); } private: |