aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/tflite_driver.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/tflite_driver.h')
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 5493ba3631..aed35f877d 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <map>
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -28,7 +29,7 @@ namespace testing {
// A test runner that feeds inputs into TF Lite and verifies its outputs.
class TfLiteDriver : public TestRunner {
public:
- explicit TfLiteDriver(bool use_nnapi);
+ explicit TfLiteDriver(bool use_nnapi, const string& delegate = "");
~TfLiteDriver() override;
void LoadModel(const string& bin_file_path) override;
@@ -52,6 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
+ std::unique_ptr<EagerDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;