aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD3
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py69
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.cc8
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc94
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.cc4
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h27
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.cc2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h3
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc16
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h4
10 files changed, 186 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a788d41ba7..89912fd116 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -162,11 +162,12 @@ cc_library(
":test_runner",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/delegates/eager:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
-cc_test(
+tf_cc_test(
name = "tflite_driver_test",
size = "small",
srcs = ["tflite_driver_test.cc"],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 52ef0d5b86..9dd5c8ae44 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1255,6 +1255,75 @@ def make_conv_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+# Note: This is a regression test for a bug (b/112436267) that Toco incorrectly
+# fuses weights when multiple Conv2D/FULLY_CONNECTED ops share the same constant
+# weight tensor.
+def make_conv_with_shared_weights_tests(zip_path):
+ """Make a test where 2 Conv ops shared the same constant weight tensor."""
+
+ test_parameters = [{
+ "input_shape": [[1, 10, 10, 3]],
+ "filter_shape": [[3, 3]],
+ "strides": [[1, 1, 1, 1]],
+ "dilations": [[1, 1, 1, 1]],
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ "channel_multiplier": [1],
+ }]
+
+ def get_tensor_shapes(parameters):
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_shape"]
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]
+ ]
+ return [input_shape, filter_shape]
+
+ def build_graph(parameters):
+ """Build a conv graph given `parameters`."""
+ input_shape, filter_shape = get_tensor_shapes(parameters)
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+
+ # Construct a constant weights tensor which will be used by both Conv2D.
+ filter_tensor = tf.constant(
+ create_tensor_data(np.float32, filter_shape), dtype=tf.float32)
+ input_tensors = [input_tensor]
+
+ # Construct 2 Conv2D operations which use exactly the same input and
+ # weights.
+ result1 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ result2 = tf.nn.conv2d(
+ input_tensor,
+ filter_tensor,
+ strides=parameters["strides"],
+ dilations=parameters["dilations"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ # Add MUL ops after Conv2D ops. These MUL ops should be fused into the
+ # weights of Conv2D.
+ result1 = result1 * 2
+ result2 = result2 * 3
+ # Add the 2 results up.
+ out = result1 + result2
+ return input_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ # Build list of input values either containing 1 tensor (input) or 2 tensors
+ # (input, filter) based on whether filter is constant or variable input.
+ input_shape, unused_filter_shape = get_tensor_shapes(parameters)
+ values = [create_tensor_data(np.float32, input_shape)]
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_depthwiseconv_tests(zip_path):
"""Make a set of tests to do convolution."""
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc
index f29c188e6c..62cbeccd33 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.cc
+++ b/tensorflow/contrib/lite/testing/generate_testspec.cc
@@ -114,7 +114,13 @@ bool GenerateTestSpecFromTensorflowModel(
// different set.
std::vector<string> input_values =
GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
- if (input_values.empty()) return false;
+ if (input_values.empty()) {
+ std::cerr << "Unable to generate input values for the TensorFlow model. "
+ "Make sure the correct values are defined for "
+ "input_layer, input_layer_type, and input_layer_shape."
+ << std::endl;
+ return false;
+ }
// Run TensorFlow.
for (int j = 0; j < input_values.size(); j++) {
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e475f256c0..e67fee2a1c 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -33,13 +33,18 @@ namespace testing {
namespace {
bool FLAGS_ignore_known_bugs = true;
-// TODO(b/71769302) zip_files_dir should have a more accurate default, if
-// possible
-string* FLAGS_zip_file_path = new string("./");
+// As archive file names are test-specific, no default is possible.
+//
+// This test supports input as both zip and tar, as a stock android image does
+// not have unzip but does have tar.
+string* FLAGS_zip_file_path = new string;
+string* FLAGS_tar_file_path = new string;
#ifndef __ANDROID__
string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
+string* FLAGS_tar_binary_path = new string("/bin/tar");
#else
string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
+string* FLAGS_tar_binary_path = new string("/system/bin/tar");
#endif
bool FLAGS_use_nnapi = false;
bool FLAGS_ignore_unsupported_nnapi = false;
@@ -98,11 +103,11 @@ std::map<string, string> kBrokenTests = {
"77546240"},
};
-// Allows test data to be unzipped into a temporary directory and makes
+// Allows test data to be unarchived into a temporary directory and makes
// sure those temporary directories are removed later.
-class ZipEnvironment : public ::testing::Environment {
+class ArchiveEnvironment : public ::testing::Environment {
public:
- ~ZipEnvironment() override {}
+ ~ArchiveEnvironment() override {}
// Delete all temporary directories on teardown.
void TearDown() override {
@@ -114,15 +119,26 @@ class ZipEnvironment : public ::testing::Environment {
temporary_directories_.clear();
}
- // Unzip `zip` file into a new temporary directory `out_dir`.
- tensorflow::Status UnZip(const string& zip, string* out_dir) {
+ // Unarchive `archive` file into a new temporary directory `out_dir`.
+ tensorflow::Status UnArchive(const string& zip, const string& tar,
+ string* out_dir) {
string dir;
TF_CHECK_OK(MakeTemporaryDirectory(&dir));
tensorflow::SubProcess proc;
- string unzip_binary = *FLAGS_unzip_binary_path;
- TF_CHECK_OK(env->FileExists(unzip_binary));
- TF_CHECK_OK(env->FileExists(zip));
- proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
+ if (!zip.empty()) {
+ string unzip_binary = *FLAGS_unzip_binary_path;
+ TF_CHECK_OK(env->FileExists(unzip_binary));
+ TF_CHECK_OK(env->FileExists(zip));
+ proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip});
+ } else {
+ string tar_binary = *FLAGS_tar_binary_path;
+ TF_CHECK_OK(env->FileExists(tar_binary));
+ TF_CHECK_OK(env->FileExists(tar));
+ // 'o' needs to be explicitly set on Android so that
+ // untarring works as non-root (otherwise tries to chown
+ // files, which fails)
+ proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir});
+ }
proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
if (!proc.Start())
@@ -156,15 +172,15 @@ class ZipEnvironment : public ::testing::Environment {
std::vector<string> temporary_directories_;
};
-// Return the singleton zip_environment.
-ZipEnvironment* zip_environment() {
- static ZipEnvironment* env = new ZipEnvironment;
+// Return the singleton archive_environment.
+ArchiveEnvironment* archive_environment() {
+ static ArchiveEnvironment* env = new ArchiveEnvironment;
return env;
}
-// Read the manifest.txt out of the unarchived zip file. Specifically
+// Read the manifest.txt out of the unarchived archive file. Specifically
// `original_file` is the original zip file for error messages. `dir` is
-// the temporary directory where the zip file has been unarchived and
+// the temporary directory where the archive file has been unarchived and
// `test_paths` is the list of test prefixes that were in the manifest.
// Note, it is an error for a manifest to contain no tests.
tensorflow::Status ReadManifest(const string& original_file, const string& dir,
@@ -190,12 +206,22 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir,
return tensorflow::Status::OK();
}
-// Get a list of tests from a zip file `zip_file_name`.
-std::vector<string> UnarchiveZipAndFindTestNames(const string& zip_file) {
+// Get a list of tests from either zip or tar file
+std::vector<string> UnarchiveAndFindTestNames(const string& zip_file,
+ const string& tar_file) {
+ if (zip_file.empty() && tar_file.empty()) {
+ TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN,
+ "Neither zip_file nor tar_file was given"));
+ }
string decompress_tmp_dir;
- TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir));
+ TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file,
+ &decompress_tmp_dir));
std::vector<string> stuff;
- TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ if (!zip_file.empty()) {
+ TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ } else {
+ TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff));
+ }
return stuff;
}
@@ -223,8 +249,7 @@ TEST_P(OpsTest, RunZipTests) {
string message = test_driver.GetErrorMessage();
if (bug_number.empty()) {
if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
- EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter"))
- << message;
+ EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
} else {
EXPECT_TRUE(result) << message;
}
@@ -256,27 +281,34 @@ struct ZipPathParamName {
}
};
-INSTANTIATE_TEST_CASE_P(
- tests, OpsTest,
- ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)),
- ZipPathParamName());
+INSTANTIATE_TEST_CASE_P(tests, OpsTest,
+ ::testing::ValuesIn(UnarchiveAndFindTestNames(
+ *FLAGS_zip_file_path, *FLAGS_tar_file_path)),
+ ZipPathParamName());
} // namespace testing
} // namespace tflite
int main(int argc, char** argv) {
- ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment());
+ ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment());
std::vector<tensorflow::Flag> flags = {
tensorflow::Flag(
"ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs,
"If a particular model is affected by a known bug, the "
"corresponding test should expect the outputs to not match."),
- tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path,
- "Required: Location of the test zip file."),
+ tensorflow::Flag(
+ "tar_file_path", tflite::testing::FLAGS_tar_file_path,
+ "Required (or zip_file_path): Location of the test tar file."),
+ tensorflow::Flag(
+ "zip_file_path", tflite::testing::FLAGS_zip_file_path,
+ "Required (or tar_file_path): Location of the test zip file."),
tensorflow::Flag("unzip_binary_path",
tflite::testing::FLAGS_unzip_binary_path,
- "Required: Location of a suitable unzip binary."),
+ "Location of a suitable unzip binary."),
+ tensorflow::Flag("tar_binary_path",
+ tflite::testing::FLAGS_tar_binary_path,
+ "Location of a suitable tar binary."),
tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
"Whether to enable the NNAPI delegate"),
tensorflow::Flag("ignore_unsupported_nnapi",
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
index ec435ca60d..30381ba028 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.cc
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -179,7 +179,9 @@ void TfDriver::Invoke() {
auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
output_names_, {}, &output_tensors_);
if (!status.ok()) {
- Invalidate("Failed to run input data on graph");
+ Invalidate(
+ "Failed to run input data on graph. Make sure the correct value is "
+ "defined for the input and output arrays.");
}
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 695c2a3de6..3874bc31d7 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -33,6 +33,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_shape;
string output_layer;
int32_t num_runs_per_pass = 100;
+ string delegate;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -42,18 +43,21 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"Path of tensorflow lite model."),
tensorflow::Flag("input_layer", &values.input_layer,
"Names of input tensors, separated by comma. Example: "
- "input_1,input_2"),
+ "input_1,input_2."),
tensorflow::Flag("input_layer_type", &values.input_layer_type,
"Data types of input tensors, separated by comma. "
- "Example: float,int"),
+ "Example: float,int."),
tensorflow::Flag(
"input_layer_shape", &values.input_layer_shape,
- "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"),
+ "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2."),
tensorflow::Flag("output_layer", &values.output_layer,
- "Names of output tensors, separated by comma. Example "
- "output_1,output_2"),
+ "Names of output tensors, separated by comma. Example: "
+ "output_1,output_2."),
tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
- "Number of full runs in each pass."),
+ "[optional] Number of full runs in each pass."),
+ tensorflow::Flag("delegate", &values.delegate,
+ "[optional] Delegate to use for executing ops. Must be "
+ "`{\"\", EAGER}`"),
};
bool no_inputs = *argc == 1;
@@ -61,6 +65,14 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
if (!success || no_inputs || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
+ } else if (values.tensorflow_model.empty() || values.tflite_model.empty() ||
+ values.input_layer.empty() || values.input_layer_type.empty() ||
+ values.input_layer_shape.empty() || values.output_layer.empty()) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
+ } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return {};
}
return {values.tensorflow_model,
@@ -69,7 +81,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
Split<string>(values.output_layer, ","),
- values.num_runs_per_pass};
+ values.num_runs_per_pass,
+ values.delegate};
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
index 19f34c0a51..c6ca796ac2 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
@@ -33,7 +33,7 @@ bool RunDiffTest(const DiffOptions& options, int num_invocations) {
options.input_layer_shape, options.output_layer)) {
return false;
}
- TfLiteDriver tflite_driver(/*use_nnapi=*/true);
+ TfLiteDriver tflite_driver(/*use_nnapi=*/true, options.delegate);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index 4ab2f230fd..f67992139f 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -44,6 +44,9 @@ struct DiffOptions {
// each of the passes. The first pass has a single inference, while the
// second pass does multiple inferences back to back.
int num_runs_per_pass;
+ // Path to the delegate library to be loaded in order to execute ops. Must be
+ // `{"", EAGER}`.
+ string delegate;
};
// Run a single TensorFLow Lite diff test with a given options.
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4d08fb5458..71a98a3d56 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -135,7 +136,13 @@ class TfLiteDriver::Expectation {
size_t num_elements_;
};
-TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
+ : use_nnapi_(use_nnapi) {
+ if (delegate_name == "EAGER") {
+ delegate_.reset(new EagerDelegate());
+ }
+}
+
TfLiteDriver::~TfLiteDriver() {}
void TfLiteDriver::AllocateTensors() {
@@ -165,6 +172,13 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
}
interpreter_->UseNNAPI(use_nnapi_);
+ if (delegate_) {
+ if (delegate_->Apply(interpreter_.get()) != kTfLiteOk) {
+ Invalidate("Unable to the build graph using the delegate");
+ return;
+ }
+ }
+
must_allocate_tensors_ = true;
}
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 5493ba3631..aed35f877d 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <map>
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -28,7 +29,7 @@ namespace testing {
// A test runner that feeds inputs into TF Lite and verifies its outputs.
class TfLiteDriver : public TestRunner {
public:
- explicit TfLiteDriver(bool use_nnapi);
+ explicit TfLiteDriver(bool use_nnapi, const string& delegate = "");
~TfLiteDriver() override;
void LoadModel(const string& bin_file_path) override;
@@ -52,6 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
+ std::unique_ptr<EagerDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;