aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-09-11 16:37:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 16:42:08 -0700
commit8ebfc2633b355535284e6fc8970fc91dde45ed9d (patch)
tree9b5b710149f11f03c5686b38c91af888f64c1440 /tensorflow/contrib/lite/testing
parent668c079f4e6020131978b7a812c3b92eea9c47b9 (diff)
Internal change.
PiperOrigin-RevId: 212545735
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD5
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py27
2 files changed, 24 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 3a6c16cafc..a4736bfee9 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -7,7 +7,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow/contrib/lite:build_def.bzl",
"gen_zip_test",
- "generated_test_models",
+ "generated_test_models_all",
)
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
load(
@@ -29,6 +29,7 @@ load(
"--unzip_binary_path=/usr/bin/unzip",
],
}),
+ conversion_mode = conversion_mode,
data = [
":zip_%s" % test_name,
],
@@ -59,7 +60,7 @@ load(
"//tensorflow/core:android_tensorflow_test_lib",
],
}),
-) for test_name in generated_test_models()]
+) for conversion_mode, test_name in generated_test_models_all()]
test_suite(
name = "generated_zip_tests",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 32f02a4f6c..812385e706 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -80,7 +80,10 @@ parser.add_argument(
"--save_graphdefs",
action="store_true",
help="Include intermediate graphdefs in the output zip files.")
-
+parser.add_argument(
+ "--run_with_extended",
+ action="store_true",
+ help="Whether the TFLite Extended converter is being used.")
RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3
@@ -320,10 +323,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
output tflite model, log_txt from conversion
or None, log_txt if it did not convert properly.
"""
+ input_arrays = [x[0] for x in input_tensors]
data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
opts = toco_options(
data_types=data_types,
- input_arrays=[x[0] for x in input_tensors],
+ input_arrays=input_arrays,
shapes=[x[1] for x in input_tensors],
output_arrays=output_tensors,
extra_toco_options=extra_toco_options)
@@ -335,6 +339,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
graphdef_file.flush()
# TODO(aselle): Switch this to subprocess at some point.
+ if "pb2lite" in bin_path and FLAGS.run_with_extended:
+ opts = ("--input_arrays={0} --output_arrays={1}".format(
+ ",".join(input_arrays), ",".join(output_tensors)))
+ elif FLAGS.run_with_extended:
+ opts += " --allow_eager_ops --force_eager_ops"
cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
(bin_path, graphdef_file.name, output_file.name, opts,
stdout_file.name))
@@ -1502,7 +1511,7 @@ def make_split_tests(zip_path):
dtype=tf.float32, name="input", shape=parameters["input_shape"])
out = tf.split(
input_tensor, parameters["num_or_size_splits"], parameters["axis"])
- return [input_tensor], out
+ return [input_tensor], [out[0]]
def build_inputs(parameters, sess, inputs, outputs):
values = [create_tensor_data(np.float32, parameters["input_shape"])]
@@ -2510,10 +2519,12 @@ def make_topk_tests(zip_path):
shape=parameters["input_shape"])
if parameters["input_k"] is not None:
k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+ inputs = [input_value, k]
else:
k = tf.constant(3, name="k")
+ inputs = [input_value]
out = tf.nn.top_k(input_value, k)
- return [input_value, k], [out[1]]
+ return inputs, [out[1]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
@@ -3208,7 +3219,7 @@ def make_unpack_tests(zip_path):
input_tensor = tf.placeholder(
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
- return [input_tensor], outs
+ return [input_tensor], [outs[0]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
@@ -3286,7 +3297,11 @@ def main(unused_args):
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
- test_function = ("make_%s_tests" % out.replace(".zip", ""))
+ # Some zip filenames contain a postfix identifying the conversion mode. The
+ # list of valid conversion modes is defined in
+ # generated_test_conversion_modes() in build_def.bzl.
+ test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
+ "pb2lite", "").replace("toco-extended", "").rstrip("_")))
if test_function not in globals():
raise RuntimeError("Can't find a test function to create %r. Tried %r" %
(out, test_function))