aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/build_def.bzl3
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py106
2 files changed, 45 insertions, 64 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 2813d1c347..b8f6b7fd59 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -200,8 +200,7 @@ def gen_zipped_test_files(name, files):
native.genrule(
name = name + "_" + f + ".files",
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
- + " --zip_to_output " + f +
- " $(@D) zipped"),
+ + " --zip_to_output " + f + " $(@D)"),
outs = [out_file],
tools = [
":generate_examples",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 8045052452..f919517e93 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -17,10 +17,9 @@
Usage:
-generate_examples <output directory> zipped
+generate_examples <output directory>
bazel run //tensorflow/contrib/lite/testing:generate_examples
- third_party/tensorflow/contrib/lite/testing/generated_examples zipped
"""
from __future__ import absolute_import
from __future__ import division
@@ -52,8 +51,6 @@ from tensorflow.python.ops import rnn
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
help="Directory where the outputs will be go.")
-# TODO(ahentz): remove this flag
-parser.add_argument("type", help="zipped")
parser.add_argument("--zip_to_output",
type=str,
help="Particular zip to output.",
@@ -543,6 +540,18 @@ def make_pool_tests(pool_op_in):
return f
+def make_l2_pool_tests(zip_path):
+ make_pool_tests(make_l2_pool)(zip_path)
+
+
+def make_avg_pool_tests(zip_path):
+ make_pool_tests(tf.nn.avg_pool)(zip_path)
+
+
+def make_max_pool_tests(zip_path):
+ make_pool_tests(tf.nn.max_pool)(zip_path)
+
+
def make_relu_tests(zip_path):
"""Make a set of tests to do relu."""
@@ -902,6 +911,22 @@ def make_binary_op_tests_func(binary_operator):
return lambda zip_path: make_binary_op_tests(zip_path, binary_operator)
+def make_add_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.add)
+
+
+def make_div_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.div)
+
+
+def make_sub_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.subtract)
+
+
+def make_mul_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.multiply)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
@@ -1169,7 +1194,7 @@ def make_split_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_concatenation_tests(zip_path):
+def make_concat_tests(zip_path):
"""Make a set of tests to do concatenation."""
test_parameters = [{
@@ -1966,69 +1991,26 @@ def main(unused_args):
if not os.path.isdir(x):
raise RuntimeError("Failed to create dir %r" % x)
- if FLAGS.type == "zipped":
- opstest_path = os.path.join(FLAGS.output_path)
- mkdir_if_not_exist(opstest_path)
- def _path(filename):
- return os.path.join(opstest_path, filename)
-
- dispatch = {
- "control_dep.zip": make_control_dep_tests,
- "add.zip": make_binary_op_tests_func(tf.add),
- "space_to_batch_nd.zip": make_space_to_batch_nd_tests,
- "div.zip": make_binary_op_tests_func(tf.div),
- "sub.zip": make_binary_op_tests_func(tf.subtract),
- "batch_to_space_nd.zip": make_batch_to_space_nd_tests,
- "conv.zip": make_conv_tests,
- "constant.zip": make_constant_tests,
- "depthwiseconv.zip": make_depthwiseconv_tests,
- "concat.zip": make_concatenation_tests,
- "fully_connected.zip": make_fully_connected_tests,
- "global_batch_norm.zip": make_global_batch_norm_tests,
- "gather.zip": make_gather_tests,
- "fused_batch_norm.zip": make_fused_batch_norm_tests,
- "l2norm.zip": make_l2norm_tests,
- "local_response_norm.zip": make_local_response_norm_tests,
- "mul.zip": make_binary_op_tests_func(tf.multiply),
- "relu.zip": make_relu_tests,
- "relu1.zip": make_relu1_tests,
- "relu6.zip": make_relu6_tests,
- "prelu.zip": make_prelu_tests,
- "l2_pool.zip": make_pool_tests(make_l2_pool),
- "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
- "max_pool.zip": make_pool_tests(tf.nn.max_pool),
- "pad.zip": make_pad_tests,
- "reshape.zip": make_reshape_tests,
- "resize_bilinear.zip": make_resize_bilinear_tests,
- "sigmoid.zip": make_sigmoid_tests,
- "softmax.zip": make_softmax_tests,
- "space_to_depth.zip": make_space_to_depth_tests,
- "topk.zip": make_topk_tests,
- "split.zip": make_split_tests,
- "transpose.zip": make_transpose_tests,
- "mean.zip": make_mean_tests,
- "squeeze.zip": make_squeeze_tests,
- "strided_slice.zip": make_strided_slice_tests,
- "exp.zip": make_exp_tests,
- "log_softmax.zip": make_log_softmax_tests,
- "lstm.zip": make_lstm_tests,
- "maximum.zip": make_maximum_tests,
- }
- out = FLAGS.zip_to_output
- bin_path = FLAGS.toco
- if out in dispatch:
- dispatch[out](_path(out))
- else:
- raise RuntimeError("Invalid zip to output %r" % out)
+ opstest_path = os.path.join(FLAGS.output_path)
+ mkdir_if_not_exist(opstest_path)
- else:
- raise RuntimeError("Invalid argument for type of generation.")
+ out = FLAGS.zip_to_output
+ bin_path = FLAGS.toco
+ test_function = ("make_%s_tests" % out.replace(".zip", ""))
+ if test_function not in globals():
+ raise RuntimeError("Can't find a test function to create %r. Tried %r" %
+ (out, test_function))
+
+ # TODO(ahentz): accessing globals() is not very elegant. We should either
+ # break this file into multiple tests or use decorator-based registration to
+ # avoid using globals().
+ globals()[test_function](os.path.join(opstest_path, out))
if __name__ == "__main__":
FLAGS, unparsed = parser.parse_known_args()
if unparsed:
- print("Usage: %s <path out> zipped <zip file to generate>")
+ print("Usage: %s <path out> <zip file to generate>")
else:
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)