diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_driver.cc')
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_driver.cc | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 4d08fb5458..71a98a3d56 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -17,6 +17,7 @@ limitations under the License. #include <iostream> #include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -135,7 +136,13 @@ class TfLiteDriver::Expectation { size_t num_elements_; }; -TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} +TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name) + : use_nnapi_(use_nnapi) { + if (delegate_name == "EAGER") { + delegate_.reset(new EagerDelegate()); + } +} + TfLiteDriver::~TfLiteDriver() {} void TfLiteDriver::AllocateTensors() { @@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { } interpreter_->UseNNAPI(use_nnapi_); + if (delegate_) { + if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) { + Invalidate("Unable to the build graph using the delegate"); + return; + } + } + must_allocate_tensors_ = true; } |