aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/tflite_driver.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_driver.cc')
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4d08fb5458..71a98a3d56 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -135,7 +136,13 @@ class TfLiteDriver::Expectation {
size_t num_elements_;
};
-TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
+ : use_nnapi_(use_nnapi) {
+ if (delegate_name == "EAGER") {
+ delegate_.reset(new EagerDelegate());
+ }
+}
+
TfLiteDriver::~TfLiteDriver() {}
void TfLiteDriver::AllocateTensors() {
@@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
}
interpreter_->UseNNAPI(use_nnapi_);
+ if (delegate_) {
+ if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) {
+ Invalidate("Unable to the build graph using the delegate");
+ return;
+ }
+ }
+
must_allocate_tensors_ = true;
}