aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-08-14 18:28:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 18:32:22 -0700
commit63a49c712edd3b2ee990a9f98b766b24190d3ccb (patch)
tree56460c939b3abcbb2cdb846dc40c60ff89c216a6 /tensorflow/contrib
parent7a3595d2b399770c2fb25d61fc524d63310cc134 (diff)
Adds support for Eager delegate to tflite_diff.
PiperOrigin-RevId: 208752057
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD3
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.cc8
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.cc4
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h27
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.cc2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h3
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc16
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h4
8 files changed, 54 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a788d41ba7..89912fd116 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -162,11 +162,12 @@ cc_library(
":test_runner",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/delegates/eager:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
-cc_test(
+tf_cc_test(
name = "tflite_driver_test",
size = "small",
srcs = ["tflite_driver_test.cc"],
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc
index f29c188e6c..62cbeccd33 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.cc
+++ b/tensorflow/contrib/lite/testing/generate_testspec.cc
@@ -114,7 +114,13 @@ bool GenerateTestSpecFromTensorflowModel(
// different set.
std::vector<string> input_values =
GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
- if (input_values.empty()) return false;
+ if (input_values.empty()) {
+ std::cerr << "Unable to generate input values for the TensorFlow model. "
+ "Make sure the correct values are defined for "
+ "input_layer, input_layer_type, and input_layer_shape."
+ << std::endl;
+ return false;
+ }
// Run TensorFlow.
for (int j = 0; j < input_values.size(); j++) {
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
index ec435ca60d..30381ba028 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.cc
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -179,7 +179,9 @@ void TfDriver::Invoke() {
auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
output_names_, {}, &output_tensors_);
if (!status.ok()) {
- Invalidate("Failed to run input data on graph");
+ Invalidate(
+ "Failed to run input data on graph. Make sure the correct value is "
+ "defined for the input and output arrays.");
}
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 695c2a3de6..3874bc31d7 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_shape;
string output_layer;
int32_t num_runs_per_pass = 100;
+ string delegate;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"Path of tensorflow lite model."),
tensorflow::Flag("input_layer", &values.input_layer,
"Names of input tensors, separated by comma. Example: "
- "input_1,input_2"),
+ "input_1,input_2."),
tensorflow::Flag("input_layer_type", &values.input_layer_type,
"Data types of input tensors, separated by comma. "
- "Example: float,int"),
+ "Example: float,int."),
tensorflow::Flag(
"input_layer_shape", &values.input_layer_shape,
- "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"),
+ "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."),
tensorflow::Flag("output_layer", &values.output_layer,
- "Names of output tensors, separated by comma. Example "
- "output_1,output_2"),
+ "Names of output tensors, separated by comma. Example: "
+ "output_1,output_2."),
tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
- "Number of full runs in each pass."),
+ "[optional] Number of full runs in each pass."),
+ tensorflow::Flag("delegate", &values.delegate,
+ "[optional] Delegate to use for executing ops. Must be "
+ "`{\"\", EAGER}`"),
};
bool no_inputs = *argc == 1;
@@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
+ } else if (values.tensorflow_model.empty() || values.tflite_model.empty() ||
+ values.input_layer.empty() || values.input_layer_type.empty() ||
+ values.input_layer_shape.empty() || values.output_layer.empty()) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
+ } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
}
return {values.tensorflow_model,
@@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
Split<string>(values.output_layer, ","),
- values.num_runs_per_pass};
+ values.num_runs_per_pass,
+ values.delegate};
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
index 19f34c0a51..c6ca796ac2 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
@@ -33,7 +33,7 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
options.input_layer_shape, options.output_layer)) {
return false;
}
- TfLiteDriver tflite_driver(/*use_nnapi=*/true);
+ TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index 4ab2f230fd..f67992139f 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -44,6 +44,9 @@ struct DiffOptions {
// each of the passes. The first pass has a single inference, while the
// second pass does multiple inferences back to back.
int num_runs_per_pass;
+ // Path to the delegate library to be loaded in order to execute ops. Must be
+ // `{"", EAGER}`.
+ string delegate;
};
// Run a single TensorFLow Lite diff test with a given options.
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4d08fb5458..71a98a3d56 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -135,7 +136,13 @@ class TfLiteDriver::Expectation {
size_t num_elements_;
};
-TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
+ : use_nnapi_(use_nnapi) {
+ if (delegate_name == "EAGER") {
+ delegate_.reset(new EagerDelegate());
+ }
+}
+
TfLiteDriver::~TfLiteDriver() {}
void TfLiteDriver::AllocateTensors() {
@@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
}
interpreter_->UseNNAPI(use_nnapi_);
+ if (delegate_) {
+ if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) {
+ Invalidate("Unable to the build graph using the delegate");
+ return;
+ }
+ }
+
must_allocate_tensors_ = true;
}
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 5493ba3631..aed35f877d 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <map>
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -28,7 +29,7 @@ namespace testing {
// A test runner that feeds inputs into TF Lite and verifies its outputs.
class TfLiteDriver : public TestRunner {
public:
- explicit TfLiteDriver(bool use_nnapi);
+ explicit TfLiteDriver(bool use_nnapi, const string& delegate = "");
~TfLiteDriver() override;
void LoadModel(const string& bin_file_path) override;
@@ -52,6 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
+ std::unique_ptr<EagerDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;