aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD213
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py1189
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples_report.py125
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc279
-rw-r--r--tensorflow/contrib/lite/testing/message.cc96
-rw-r--r--tensorflow/contrib/lite/testing/message.h82
-rw-r--r--tensorflow/contrib/lite/testing/message_test.cc121
-rw-r--r--tensorflow/contrib/lite/testing/nnapi_example.cc114
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.cc335
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.h74
-rw-r--r--tensorflow/contrib/lite/testing/split.cc42
-rw-r--r--tensorflow/contrib/lite/testing/split.h77
-rw-r--r--tensorflow/contrib/lite/testing/split_test.cc57
-rw-r--r--tensorflow/contrib/lite/testing/test_runner.h124
-rw-r--r--tensorflow/contrib/lite/testing/test_runner_test.cc84
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc208
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h62
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver_test.cc61
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.cc95
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.h42
-rw-r--r--tensorflow/contrib/lite/testing/tokenize_test.cc105
21 files changed, 3585 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
new file mode 100644
index 0000000000..5e40a13d3c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -0,0 +1,213 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite:build_def.bzl",
+ "gen_zipped_test_files",
+)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+gen_zipped_test_files(
+ name = "optest",
+ files = [
+ "add.zip",
+ "avg_pool.zip",
+ "concat.zip",
+ "constant.zip",
+ "control_dep.zip",
+ "conv.zip",
+ "depthwiseconv.zip",
+ "fully_connected.zip",
+ "fused_batch_norm.zip",
+ "global_batch_norm.zip",
+ "l2_pool.zip",
+ "l2norm.zip",
+ "local_response_norm.zip",
+ "max_pool.zip",
+ "mul.zip",
+ "relu.zip",
+ "relu1.zip",
+ "relu6.zip",
+ "reshape.zip",
+ "resize_bilinear.zip",
+ "sigmoid.zip",
+ "softmax.zip",
+ "space_to_depth.zip",
+ ],
+)
+
+py_binary(
+ name = "generate_examples",
+ srcs = ["generate_examples.py"],
+ data = [
+ "//tensorflow/contrib/lite/toco",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":generate_examples_report",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:graph_util",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "generate_examples_report",
+ srcs = ["generate_examples_report.py"],
+ srcs_version = "PY2AND3",
+)
+
+cc_library(
+ name = "parse_testdata_lib",
+ srcs = ["parse_testdata.cc"],
+ hdrs = ["parse_testdata.h"],
+ deps = [
+ ":message",
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ ],
+)
+
+cc_library(
+ name = "message",
+ srcs = ["message.cc"],
+ hdrs = ["message.h"],
+ deps = [":tokenize"],
+)
+
+cc_test(
+ name = "message_test",
+ srcs = ["message_test.cc"],
+ deps = [
+ ":message",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "split",
+ srcs = ["split.cc"],
+ hdrs = ["split.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "split_test",
+ size = "small",
+ srcs = ["split_test.cc"],
+ deps = [
+ ":split",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tflite_driver",
+ srcs = ["tflite_driver.cc"],
+ hdrs = ["tflite_driver.h"],
+ deps = [
+ ":split",
+ ":test_runner",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "tflite_driver_test",
+ size = "small",
+ srcs = ["tflite_driver_test.cc"],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ deps = [
+ ":tflite_driver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "tokenize",
+ srcs = ["tokenize.cc"],
+ hdrs = ["tokenize.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "tokenize_test",
+ srcs = ["tokenize_test.cc"],
+ deps = [
+ ":tokenize",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "test_runner",
+ hdrs = ["test_runner.h"],
+ deps = [
+ "//tensorflow/contrib/lite:string",
+ ],
+)
+
+cc_test(
+ name = "test_runner_test",
+ srcs = ["test_runner_test.cc"],
+ deps = [
+ ":test_runner",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_binary(
+ name = "nnapi_example",
+ srcs = ["nnapi_example.cc"],
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "generated_examples_zip_test",
+ size = "medium",
+ srcs = ["generated_examples_zip_test.cc"],
+ data = [":optest"],
+ shard_count = 10,
+ deps = [
+ ":parse_testdata_lib",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_googletest//:gtest",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
new file mode 100644
index 0000000000..86540d58a6
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -0,0 +1,1189 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate a series of TensorFlow graphs that become tflite test cases.
+
+Usage:
+
+generate_examples <output directory> zipped
+
+bazel run //tensorflow/contrib/lite/testing:generate_examples
+ third_party/tensorflow/contrib/lite/testing/generated_examples zipped
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import itertools
+import os
+import re
+import sys
+import tempfile
+import traceback
+import zipfile
+import numpy as np
+from six import StringIO
+import tensorflow as tf
+from google.protobuf import text_format
+# TODO(aselle): switch to TensorFlow's resource_loader
+from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
+from tensorflow.python.framework import graph_util as tf_graph_util
+
+parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
+parser.add_argument("output_path",
+ help="Directory where the outputs will be go.")
+# TODO(ahentz): remove this flag
+parser.add_argument("type", help="zipped")
+parser.add_argument("--zip_to_output",
+ type=str,
+ help="Particular zip to output.",
+ required=False)
+parser.add_argument("--toco",
+ type=str,
+ help="Path to toco tool.",
+ required=True)
+parser.add_argument(
+ "--known_bugs_are_errors",
+ action="store_true",
+ help=("If a particular model is affected by a known bug,"
+ " count it as a toco error."))
+parser.add_argument(
+ "--ignore_toco_errors",
+ action="store_true",
+ help="Raise an exception if any toco error is encountered.")
+parser.add_argument(
+ "--save_graphdefs",
+ action="store_true",
+ help="Include intermediate graphdefs in the output zip files.")
+
+
+RANDOM_SEED = 342
+TEST_INPUT_DEPTH = 3
+
+
+# A map from regular expression to bug number. Any test failure with label
+# matching the expression will be considered due to the corresponding bug.
+KNOWN_BUGS = {
+ # TOCO doesn't support scalars as input.
+ r"relu.*input_shape=\[\]": "67587484",
+ r"sigmoid.*input_shape=\[\]": "67645668",
+ # Concat doesn't work with a single input tensor
+ r"concat.*num_tensors=1": "67378344",
+ # Transposition in MatMul is not supported.
+ r"fully_connected.*transpose_.=True": "67586970",
+ # Softmax graphs are too complex.
+ r"softmax.*dim=0": "67749831",
+ r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
+ # SpaceToDepth only supports float32.
+ r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
+}
+
+
+def toco_options(data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency):
+ """Create TOCO options to process a model.
+
+ Args:
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: name of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ the options in a string.
+ """
+ shape_str = ":".join([",".join(str(y) for y in x) for x in shapes])
+ inference_type = "FLOAT"
+ # TODO(ahentz): if we get multi-input quantization to work we need this
+ # to change
+ if data_types[0] == "QUANTIZED_UINT8":
+ inference_type = "QUANTIZED_UINT8"
+ s = (" --input_types=%s" % ",".join(data_types) +
+ " --inference_type=%s" % inference_type +
+ " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
+ " --input_arrays=%s" % ",".join(input_arrays) +
+ " --input_shapes=%s" % shape_str +
+ " --output_arrays=%s" % ",".join(output_arrays))
+ if drop_control_dependency:
+ s += " --drop_control_dependency"
+ return s
+
+
+def write_toco_options(filename,
+ data_types,
+ input_arrays,
+ output_arrays,
+ shapes,
+ drop_control_dependency=False):
+ """Create TOCO options to process a model.
+
+ Args:
+ filename: Filename to write the options to.
+ data_types: input and inference types used by TOCO.
+ input_arrays: names of the input tensors
+ output_arrays: names of the output tensors
+ shapes: shapes of the input tensors
+ drop_control_dependency: whether to ignore control dependency nodes.
+ """
+ with open(filename, "w") as fp:
+ fp.write(
+ toco_options(
+ data_types=data_types,
+ input_arrays=input_arrays,
+ output_arrays=output_arrays,
+ shapes=shapes,
+ drop_control_dependency=drop_control_dependency))
+
+
+def write_examples(fp, examples):
+ """Given a list `examples`, write a text format representation.
+
+ The file format is csv like with a simple repeated pattern. We would ike
+ to use proto here, but we can't yet due to interfacing with the Android
+ team using this format.
+
+ Args:
+ fp: File-like object to write to.
+ examples: Example dictionary consiting of keys "inputs" and "outputs"
+ """
+
+ def write_tensor(fp, x):
+ """Write tensor in file format supported by TFLITE example."""
+ fp.write("dtype,%s\n" % x.dtype)
+ fp.write("shape," + ",".join(map(str, x.shape)) + "\n")
+ # Output 9 digits after the point to ensure the precision is good enough.
+ values = ["{:.9f}".format(value) for value in list(x.flatten())]
+ fp.write("values," + ",".join(values) + "\n")
+
+ fp.write("test_cases,%d\n" % len(examples))
+ for example in examples:
+ fp.write("inputs,%d\n" % len(example["inputs"]))
+ for i in example["inputs"]:
+ write_tensor(fp, i)
+ fp.write("outputs,%d\n" % len(example["outputs"]))
+ for i in example["outputs"]:
+ write_tensor(fp, i)
+
+
+def write_test_cases(fp, model_name, examples):
+ """Given a dictionary of `examples`, write a text format representation.
+
+ The file format is protocol-buffer-like, even though we don't use proto due
+ to the needs of the Android team.
+
+ Args:
+ fp: File-like object to write to.
+ model_name: Filename where the model was written to, relative to filename.
+ examples: Example dictionary consiting of keys "inputs" and "outputs"
+ """
+
+ fp.write("load_model: %s\n" % os.path.basename(model_name))
+ for example in examples:
+ fp.write("reshape {\n")
+ for t in example["inputs"]:
+ fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n")
+ fp.write("}\n")
+ fp.write("invoke {\n")
+
+ for t in example["inputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" input: \"" + ",".join(values) + "\"\n")
+ for t in example["outputs"]:
+ values = ["{:.9f}".format(value) for value in list(t.flatten())]
+ fp.write(" output: \"" + ",".join(values) + "\"\n")
+ fp.write("}\n")
+
+
+_TF_TYPE_INFO = {
+ tf.float32: (np.float32, "FLOAT"),
+ tf.float16: (np.float16, "FLOAT"),
+ tf.int32: (np.int32, "INT32"),
+ tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
+ tf.int64: (np.int64, "INT64"),
+}
+
+
+def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
+ """Build tensor data spreading the range [min_value, max_value)."""
+
+ if dtype in _TF_TYPE_INFO:
+ dtype = _TF_TYPE_INFO[dtype][0]
+
+ if dtype in (tf.float32, tf.float16):
+ value = (max_value-min_value)*np.random.random_sample(shape)+min_value
+ elif dtype in (tf.int32, tf.uint8, tf.int64):
+ value = np.random.random_integers(min_value, max_value, shape)
+ return value.astype(dtype)
+
+
+def freeze_graph(session, outputs):
+ """Freeze the current graph.
+
+ Args:
+ session: Tensorflow sessions containing the graph
+ outputs: List of output tensors
+
+ Returns:
+ The frozen graph_def.
+ """
+ return tf_graph_util.convert_variables_to_constants(
+ session, session.graph.as_graph_def(), [x.op.name for x in outputs])
+
+
+def make_control_dep_tests(zip_path):
+ """Make a set of tests that use control dependencies."""
+
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32)
+ assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1)
+ with tf.control_dependencies([assert_op]):
+ out = tf.nn.conv2d(input_tensor, filter_value,
+ strides=(1, 1, 1, 1), padding="SAME")
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
+ drop_control_dependency=True)
+
+
+def toco_convert(graph_def_str, input_tensors, output_tensors,
+ drop_control_dependency=False):
+ """Convert a model's graph def into a tflite model.
+
+ NOTE: this currently shells out to the toco binary, but we would like
+ convert to Python API tooling in the future.
+
+ Args:
+ graph_def_str: Graph def proto in serialized string format.
+ input_tensors: List of input tensor tuples `(name, shape, type)`
+ output_tensors: List of output tensors (names)
+ drop_control_dependency: whether to ignore control dependency nodes.
+
+ Returns:
+ output tflite model, log_txt from conversion
+ or None, log_txt if it did not convert properly.
+ """
+ data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
+ opts = toco_options(
+ data_types=data_types,
+ input_arrays=[x[0] for x in input_tensors],
+ shapes=[x[1] for x in input_tensors],
+ output_arrays=output_tensors,
+ drop_control_dependency=drop_control_dependency)
+
+ with tempfile.NamedTemporaryFile() as graphdef_file, \
+ tempfile.NamedTemporaryFile() as output_file, \
+ tempfile.NamedTemporaryFile("w+") as stdout_file:
+ graphdef_file.write(graph_def_str)
+ graphdef_file.flush()
+
+ # TODO(aselle): Switch this to subprocess at some point.
+ cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
+ (bin_path, graphdef_file.name, output_file.name, opts,
+ stdout_file.name))
+ exit_code = os.system(cmd)
+ log = (
+ cmd + "exited with code %d" % exit_code + "\n------------------\n" +
+ stdout_file.read())
+ return (None if exit_code != 0 else output_file.read()), log
+
+
+def make_zip_of_tests(zip_path,
+ test_parameters,
+ make_graph,
+ make_test_inputs,
+ drop_control_dependency=False):
+ """Helper to make a zip file of a bunch of TensorFlow models.
+
+ This does a cartestian product of the dictionary of test_parameters and
+ calls make_graph() for each item in the cartestian product set.
+ If the graph is built successfully, then make_test_inputs() is called to
+ build expected input/output value pairs. The model is then converted to tflite
+ with toco, and the examples are serialized with the tflite model into a zip
+ file (2 files per item in the cartesian product set).
+
+ Args:
+ zip_path: Path of zip file to write
+ test_parameters: Dictionary mapping to lists for each parameter.
+ e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}`
+ make_graph: function that takes current parameters and returns tuple
+ `[input1, input2, ...], [output1, output2, ...]`
+ make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
+ `output_tensors` and returns tuple `(input_values, output_values)`.
+ drop_control_dependency: whether to ignore control dependency nodes.
+ Raises:
+ RuntimeError: if there are toco errors that can't be ignored.
+ """
+
+ # TODO(aselle): Make this allow multiple inputs outputs.
+ archive = zipfile.PyZipFile(zip_path, "w")
+ zip_manifest = []
+ convert_report = []
+ toco_errors = 0
+ for parameters in test_parameters:
+ keys = parameters.keys()
+ for curr in itertools.product(*parameters.values()):
+ label = zip_path.replace(".zip", "") + (",".join(
+ "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
+ if label[0] == "/":
+ label = label[1:]
+ param_dict = dict(zip(keys, curr))
+
+ def build_example(label, param_dict_real):
+ """Build the model with parameter values set in param_dict_real.
+
+ Args:
+ label: Label of the model (i.e. the filename in the zip).
+ param_dict_real: Parameter dictionary (arguments to the factories
+ make_graph and make_test_inputs)
+ Returns:
+ (tflite_model_binary, report) where tflite_model_binary is the
+ serialized flatbuffer as a string and report is a dictionary with
+ keys `toco_log` (log of toco conversion), `tf_log` (log of tf
+ conversion), `toco` (a string of success status of the conversion),
+ `tf` (a string success status of the conversion).
+ """
+
+ np.random.seed(RANDOM_SEED)
+ report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED}
+
+ # Build graph
+ report["tf_log"] = ""
+ report["toco_log"] = ""
+ tf.reset_default_graph()
+
+ try:
+ inputs, outputs = make_graph(param_dict_real)
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+
+ sess = tf.Session()
+ try:
+ baseline_inputs, baseline_outputs = (make_test_inputs(
+ param_dict_real, sess, inputs, outputs))
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
+ report["toco"] = report_lib.FAILED
+ report["tf"] = report_lib.SUCCESS
+
+ # Convert graph to toco
+ tflite_model_binary, toco_log = toco_convert(
+ sess.graph_def.SerializeToString(),
+ [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
+ input_tensor.dtype) for input_tensor in inputs],
+ [out.name.split(":")[0]
+ for out in outputs], drop_control_dependency)
+ report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
+ else report_lib.FAILED)
+ report["toco_log"] = toco_log
+
+ if FLAGS.save_graphdefs:
+ archive.writestr(label + ".pb",
+ text_format.MessageToString(sess.graph_def),
+ zipfile.ZIP_DEFLATED)
+
+ if tflite_model_binary:
+ archive.writestr(label + ".bin", tflite_model_binary,
+ zipfile.ZIP_DEFLATED)
+ example = {"inputs": baseline_inputs, "outputs": baseline_outputs}
+
+ example_fp = StringIO()
+ write_examples(example_fp, [example])
+ archive.writestr(label + ".inputs",
+ example_fp.getvalue(), zipfile.ZIP_DEFLATED)
+
+ example_fp2 = StringIO()
+ write_test_cases(example_fp2, label + ".bin", [example])
+ archive.writestr(label + "_tests.txt",
+ example_fp2.getvalue(), zipfile.ZIP_DEFLATED)
+
+ zip_manifest.append(label + "\n")
+
+ return tflite_model_binary, report
+
+ _, report = build_example(label, param_dict)
+
+ if report["toco"] == report_lib.FAILED:
+ ignore_error = False
+ if not FLAGS.known_bugs_are_errors:
+ for pattern, bug_number in KNOWN_BUGS.items():
+ if re.search(pattern, label):
+ print("Ignored TOCO error due to bug %s" % bug_number)
+ ignore_error = True
+ if not ignore_error:
+ toco_errors += 1
+ print("-----------------\ntoco error!\n%s\n-----------------\n" %
+ report["toco_log"])
+
+ convert_report.append((param_dict, report))
+ report_io = StringIO()
+ report_lib.make_report_table(report_io, zip_path, convert_report)
+ archive.writestr("report.html", report_io.getvalue())
+
+ archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED)
+
+ # Log statistics of what succeeded
+ total_conversions = len(convert_report)
+ tf_success = sum(1 for x in convert_report
+ if x[1]["tf"] == report_lib.SUCCESS)
+ toco_success = sum(1 for x in convert_report
+ if x[1]["toco"] == report_lib.SUCCESS)
+ percent = 0
+ if tf_success > 0:
+ percent = float(toco_success) / float(tf_success) * 100.
+ tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs "
+ " and %d TOCO converted graphs (%.1f%%"), zip_path,
+ total_conversions, tf_success, toco_success, percent)
+
+ if not FLAGS.ignore_toco_errors and toco_errors > 0:
+ raise RuntimeError(
+ "Found %d errors while generating toco models" % toco_errors)
+
+
+def make_pool_tests(pool_op_in):
+ """Make a set of tests to do average pooling.
+
+ Args:
+ pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool`.
+
+ Returns:
+ A function representing the true generator (after curried pool_op_in).
+ """
+
+ pool_op = pool_op_in
+
+ def f(zip_path):
+ """Actual function that generates examples.
+
+ Args:
+ zip_path: path to write zip to.
+ """
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]],
+ # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]).
+ "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = pool_op(
+ input_tensor,
+ ksize=parameters["ksize"],
+ strides=parameters["strides"],
+ data_format=parameters["data_format"],
+ padding=parameters["padding"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(tf.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+ return f
+
+
+def make_relu_tests(zip_path):
+ """Make a set of tests to do relu."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+ [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.relu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu1_tests(zip_path):
+ """Make a set of tests to do relu1."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ # Note that the following is not supported:
+ # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
+ out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0))
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_relu6_tests(zip_path):
+ """Make a set of tests to do relu6."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.relu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-3, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+# This function tests various TensorFLow functions that generates Const op,
+# including `tf.ones`, `tf.zeros` and random functions.
+def make_constant_tests(zip_path):
+ """Make a set of tests to do constant ops."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
+ }]
+
+ def build_graph(parameters):
+ # Since Toco & Tflite can't have a single constant op in the entire graph,
+ # this test adds a zero tesnor with a constant op tensor.
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape"])
+ out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1
+ return [input1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = np.zeros(parameters["input_shape"],
+ dtype=_TF_TYPE_INFO[parameters["dtype"]][0])
+ return [input1], sess.run(outputs, feed_dict={inputs[0]: input1})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_add_tests(zip_path):
+ """Make a set of tests to do add with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.add(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={
+ inputs[0]: input1,
+ inputs[1]: input2
+ })
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_mul_tests(zip_path):
+ """Make a set of tests to do mul with and without broadcast."""
+
+ # These parameters are split because we don't support broadcasting.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[5]],
+ "input_shape_2": [[5]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape_1": [[1, 3, 4, 3]],
+ "input_shape_2": [[3]],
+ }]
+
+ def build_graph(parameters):
+ input1 = tf.placeholder(dtype=parameters["dtype"], name="input1",
+ shape=parameters["input_shape_1"])
+ input2 = tf.placeholder(dtype=parameters["dtype"], name="input2",
+ shape=parameters["input_shape_2"])
+ out = tf.multiply(input1, input2)
+ return [input1, input2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input1 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_1"])
+ input2 = create_tensor_data(parameters["dtype"],
+ parameters["input_shape_2"])
+ return [input1, input2], sess.run(
+ outputs, feed_dict={inputs[0]: input1,
+ inputs[1]: input2})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_global_batch_norm_tests(zip_path):
+ """Make a set of tests to do batch_norm_with_global_normalization."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]],
+ "epsilon": [0.1, 0.0001],
+ "scale_after": [True, False],
+ }]
+
+ def build_graph(parameters):
+ """Build the global batch norm testing graph."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ x_norm = tf.nn.batch_norm_with_global_normalization(
+ x, mean, variance, scale, offset,
+ parameters["epsilon"], parameters["scale_after"])
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fused_batch_norm_tests(zip_path):
+ """Make a set of tests to do fused_batch_norm."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 1, 6, 2]],
+ "epsilon": [0.001, 0.1],
+ }]
+
+ def build_graph(parameters):
+ """Build the testing graph for fused batch normalization."""
+ input_shape = parameters["input_shape"]
+ scale_shape = input_shape[3]
+
+ scale = create_tensor_data(parameters["dtype"], scale_shape)
+ offset = create_tensor_data(parameters["dtype"], scale_shape)
+ mean = create_tensor_data(parameters["dtype"], scale_shape)
+ variance = create_tensor_data(parameters["dtype"], scale_shape)
+
+ x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ [x_norm, _, _] = tf.nn.fused_batch_norm(
+ x, scale, offset, mean, variance,
+ parameters["epsilon"], data_format="NHWC", is_training=False)
+
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.add(input_tensor, x_norm)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_conv_tests(zip_path):
+ """Make a set of tests to do convolution."""
+
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_shape": [[1, 1, 3, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }, {
+ "input_shape": [[2, 14, 14, 2]],
+ "filter_shape": [[6, 6, 2, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ filter_values = create_tensor_data(np.float32, parameters["filter_shape"])
+ out = tf.nn.conv2d(input_tensor, filter_values,
+ strides=parameters["strides"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_depthwiseconv_tests(zip_path):
+ """Make a set of tests to do convolution."""
+
+ # Tensorflow only supports equal strides
+ test_parameters = [{
+ "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
+ "filter_size": [[1, 1], [1, 2], [3, 3]],
+ "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "channel_multiplier": [1, 2],
+ "rate": [[1, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"],
+ }, {
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_size": [[1, 1]],
+ "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "channel_multiplier": [2],
+ "rate": [[2, 2]], # Only [1, 1] is supported
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ }]
+
+ def build_graph(parameters):
+ """Build a depthwise conv graph given `parameters`."""
+ input_shape = parameters["input_shape"]
+ filter_size = parameters["filter_size"]
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=input_shape)
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]]
+ filter_values = create_tensor_data(np.float32, filter_shape)
+ out = tf.nn.depthwise_conv2d(
+ input_tensor, filter_values,
+ strides=parameters["strides"],
+ rate=parameters["rate"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(np.float32, parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_concatenation_tests(zip_path):
+ """Make a set of tests to do concatenatinon."""
+
+ test_parameters = [{
+ "base_shape": [[1, 3, 4, 3], [3, 4]],
+ "num_tensors": [1, 2, 3, 4, 5, 6],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_shape(parameters, delta):
+ """Return a tweaked version of 'base_shape'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ if axis < len(shape):
+ shape[axis] += delta
+ return shape
+
+ def build_graph(parameters):
+ all_tensors = []
+ for n in range(0, parameters["num_tensors"]):
+ input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n),
+ shape=get_shape(parameters, n))
+ all_tensors.append(input_tensor)
+ out = tf.concat(all_tensors, parameters["axis"])
+ return all_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ all_values = []
+ for n in range(0, parameters["num_tensors"]):
+ input_values = create_tensor_data(np.float32,
+ get_shape(parameters, n))
+ all_values.append(input_values)
+ return all_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, all_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_fully_connected_tests(zip_path):
+ """Make a set of tests to do fully_connected."""
+
+ test_parameters = [{
+ "shape1": [[3, 3]],
+ "shape2": [[3, 3]],
+ "transpose_a": [True, False],
+ "transpose_b": [True, False],
+ }, {
+ "shape1": [[4, 4], [1, 4], [4]],
+ "shape2": [[4, 4], [4, 1], [4]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+ }, {
+ "shape1": [[40, 37]],
+ "shape2": [[37, 40]],
+ "transpose_a": [False],
+ "transpose_b": [False],
+
+ }]
+
+ def build_graph(parameters):
+ input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1",
+ shape=parameters["shape1"])
+ input_tensor2 = create_tensor_data(np.float32, parameters["shape2"])
+ out = tf.matmul(input_tensor1, input_tensor2,
+ transpose_a=parameters["transpose_a"],
+ transpose_b=parameters["transpose_b"])
+ return [input_tensor1], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"])
+ return [input_values1], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values1])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2norm_tests(zip_path):
+ """Make a set of tests to do l2norm."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
+ [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+ "dim": [0, 1, 2, 3, [2, 3], -2],
+ "epsilon": [None, 1e-12, 1e-3],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ if parameters["epsilon"]:
+ out = tf.nn.l2_normalize(
+ input_tensor, parameters["dim"], epsilon=parameters["epsilon"])
+ else:
+ out = tf.nn.l2_normalize(input_tensor, parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_local_response_norm_tests(zip_path):
+ """Make a set of tests to do local_response_norm."""
+
+ # Chose a set of parameters
+ test_parameters = [{
+ "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]],
+ "depth_radius": [None, 0, 1, 3, 4, 5],
+ "bias": [None, 0.1, 0.3, -0.1],
+ "alpha": [None, 1, 2, -3],
+ "beta": [None, 0.5, 0.25, 2],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ out = tf.nn.local_response_normalization(
+ input_tensor, depth_radius=parameters["depth_radius"],
+ bias=parameters["bias"], alpha=parameters["alpha"],
+ beta=parameters["beta"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(
+ np.float32, parameters["input_shape"], min_value=-4, max_value=10)
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_reshape_tests(zip_path):
+ """Make a set of tests to do reshape."""
+
+ # Alll shapes below are suitable for tensors with 420 elements.
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
+ "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.reshape(input_tensor, shape=parameters["output_shape"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_resize_bilinear_tests(zip_path):
+ """Make a set of tests to do resize_bilinear."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
+ "size": [[1, 1], [4, 3], [2, 2], [5, 6]],
+ "align_corners": [None, True, False],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.image.resize_bilinear(input_tensor, size=parameters["size"],
+ align_corners=parameters["align_corners"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_sigmoid_tests(zip_path):
+ """Make a set of tests to do sigmoid."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.sigmoid(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_softmax_tests(zip_path):
+ """Make a set of tests to do softmax."""
+
+ test_parameters = [{
+ "dtype": [tf.float32],
+ "input_shape": [[1, 3, 4, 3], [2, 3]],
+ "dim": [-1, 0],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[4, 7]],
+ "dim": [-1, 1],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.nn.softmax(input_tensor, dim=parameters["dim"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_space_to_depth_tests(zip_path):
+ """Make a set of tests to do space_to_depth."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
+ "input_shape": [[2, 12, 24, 1]],
+ "block_size": [2, 3, 4],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
+ shape=parameters["input_shape"])
+ out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
+ """Given an input perform a sequence of TensorFlow ops to produce l2pool."""
+ return tf.sqrt(tf.nn.avg_pool(
+ tf.square(input_tensor), ksize=ksize, strides=strides,
+ padding=padding, data_format=data_format))
+
+
+# Toco binary path provided by the generate rule.
+bin_path = None
+
+
+def main(unused_args):
+ global bin_path
+ def mkdir_if_not_exist(x):
+ if not os.path.isdir(x):
+ os.mkdir(x)
+ if not os.path.isdir(x):
+ raise RuntimeError("Failed to create dir %r" % x)
+
+ if FLAGS.type == "zipped":
+ opstest_path = os.path.join(FLAGS.output_path)
+ mkdir_if_not_exist(opstest_path)
+ def _path(filename):
+ return os.path.join(opstest_path, filename)
+
+ dispatch = {
+ "control_dep.zip": make_control_dep_tests,
+ "add.zip": make_add_tests,
+ "conv.zip": make_conv_tests,
+ "constant.zip": make_constant_tests,
+ "depthwiseconv.zip": make_depthwiseconv_tests,
+ "concat.zip": make_concatenation_tests,
+ "fully_connected.zip": make_fully_connected_tests,
+ "global_batch_norm.zip": make_global_batch_norm_tests,
+ "fused_batch_norm.zip": make_fused_batch_norm_tests,
+ "l2norm.zip": make_l2norm_tests,
+ "local_response_norm.zip": make_local_response_norm_tests,
+ "mul.zip": make_mul_tests,
+ "relu.zip": make_relu_tests,
+ "relu1.zip": make_relu1_tests,
+ "relu6.zip": make_relu6_tests,
+ "l2_pool.zip": make_pool_tests(make_l2_pool),
+ "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
+ "max_pool.zip": make_pool_tests(tf.nn.max_pool),
+ "reshape.zip": make_reshape_tests,
+ "resize_bilinear.zip": make_resize_bilinear_tests,
+ "sigmoid.zip": make_sigmoid_tests,
+ "softmax.zip": make_softmax_tests,
+ "space_to_depth.zip": make_space_to_depth_tests,
+ }
+ out = FLAGS.zip_to_output
+ bin_path = FLAGS.toco
+ if out in dispatch:
+ dispatch[out](_path(out))
+ else:
+ raise RuntimeError("Invalid zip to output %r" % out)
+
+ else:
+ raise RuntimeError("Invalid argument for type of generation.")
+
+
+if __name__ == "__main__":
+ FLAGS, unparsed = parser.parse_known_args()
+
+ if unparsed:
+ print("Usage: %s <path out> zipped <zip file to generate>")
+ else:
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py
new file mode 100644
index 0000000000..7bcf8cd86a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generate_examples_report.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Make HTML tables that report where TF and TOCO failed to convert models.
+
+This is primarily used by generate_examples.py. See it or
+`make_report_table` for more details on usage.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cgi
+import json
+
+FAILED = "FAILED"
+SUCCESS = "SUCCESS"
+NOTRUN = "NOTRUN"
+
+
+def make_report_table(fp, title, reports):
+ """Make an HTML report of the success/failure reports.
+
+ Args:
+ fp: File-like object in which to put the html.
+ title: "Title of the zip file this pertains to."
+ reports: a list of conversion attempts. (report_args, report_vals) i.e.
+ ({"shape": [1,2,3], "type": "tf.float32"},
+ {"tf": "SUCCESS", "toco": "FAILURE", "toco_log": "Unsupported type.",
+ "tf_log": ""})
+ """
+ # sort reports by if TOCO failure and then TF failure (reversed)
+ reports.sort(key=lambda x: x[1]["toco"], reverse=False)
+ reports.sort(key=lambda x: x[1]["tf"], reverse=True)
+ def result_cell(x, row, col):
+ """Produce a cell with the condition string `x`."""
+ s = cgi.escape(repr(x), quote=True)
+ color = "#44ff44" if x == SUCCESS else (
+ "#ff4444" if x == FAILED else "#eeeeee")
+ handler = "ShowLog(%d, %d)" % (row, col)
+ fp.write("<td style='background-color: %s' onclick='%s'>%s</td>\n" % (
+ color, handler, s))
+
+ fp.write("""<html>
+<head>
+<title>tflite report</title>
+<style>
+body { font-family: Arial; }
+th { background-color: #555555; color: #eeeeee; }
+td { vertical-align: top; }
+td.horiz {width: 50%;}
+pre { white-space: pre-wrap; word-break: keep-all; }
+table {width: 100%;}
+</style>
+</head>
+""")
+ # Write the log data to a javascript variable and also make a function
+ # in javascript to show the log when an item is clicked.
+ fp.write("<script> \n")
+ fp.write("""
+function ShowLog(row, col) {
+
+var log = document.getElementById("log");
+log.innerHTML = "<pre>" + data[row][col] + "</pre>";
+}
+""")
+ fp.write("var data = \n")
+ fp.write(json.dumps([[cgi.escape(x[1]["tf_log"], quote=True),
+ cgi.escape(x[1]["toco_log"], quote=True)]
+ for x in reports]))
+ fp.write(";</script>\n")
+
+ # Write the main table and use onclick on the items that have log items.
+ fp.write("""
+<body>
+<h1>TOCO Conversion</h1>
+<h2>%s</h2>
+""" % title)
+
+ # Get a list of keys that are in any of the records.
+ param_keys = {}
+ for params, _ in reports:
+ for k in params.keys():
+ param_keys[k] = True
+
+ fp.write("<table>\n")
+ fp.write("<tr><td class='horiz'>\n")
+ fp.write("<div style='height:1000px; overflow:auto'>\n")
+ fp.write("<table>\n")
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write("<th>%s</th>\n" % cgi.escape(p, quote=True))
+ fp.write("<th>TensorFlow</th>\n")
+ fp.write("<th>TOCO</th>\n")
+ fp.write("</tr>\n")
+ for idx, (params, vals) in enumerate(reports):
+ fp.write("<tr>\n")
+ for p in param_keys:
+ fp.write(" <td>%s</td>\n" % cgi.escape(repr(params[p]), quote=True))
+
+ result_cell(vals["tf"], idx, 0)
+ result_cell(vals["toco"], idx, 1)
+ fp.write("</tr>\n")
+ fp.write("</table>\n")
+ fp.write("</div>\n")
+ fp.write("</td>\n")
+ fp.write("<td class='horiz' id='log'></td></tr>\n")
+ fp.write("</table>\n")
+ fp.write("<script>\n")
+ fp.write("</script>\n")
+ fp.write("""
+ </body>
+ </html>
+ """)
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
new file mode 100644
index 0000000000..e7df97ee54
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -0,0 +1,279 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <map>
+#include <sstream>
+#include <gtest/gtest.h>
+#include "re2/re2.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+bool FLAGS_ignore_known_bugs = true;
+} // namespace
+
+namespace tflite {
+namespace testing {
+
+// TensorFlow system environment for file system called.
+tensorflow::Env* env = tensorflow::Env::Default();
+
+// List of tests that are expected to fail when
+// --test_arg=--ignore_known_bugs=false
+// Key is a substring of the test name and value is a bug number.
+// TODO(ahentz): make sure we clean this list up frequently.
+std::map<string, string> kBrokenTests = {
+ // Add doesn't support broadcasting.
+ {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+
+ // Add only supports float32. (and "constant" tests use Add)
+ {R"(addd.*int32)", "68808744"},
+ {R"(constant.*int32)", "68808744"},
+ {R"(mul.*int32)", "68808744"},
+
+ // Toco or TFLite has a bug to deal with some constant functions with
+ // more than 1 element.
+ {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"},
+
+ // L2Norm only supports 4D tensors.
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"},
+ {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
+
+ // L2Norm only works for dim=-1.
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+
+ // ResizeBilinear looks completely incompatible with Tensorflow
+ {R"(resize_bilinear)", "67964336"},
+};
+
+// Allows test data to be unzipped into a temporary directory and makes
+// sure those temporary directories are removed later.
+class ZipEnvironment : public ::testing::Environment {
+ public:
+ ~ZipEnvironment() override {}
+
+ // Delete all temporary directories on teardown.
+ void TearDown() override {
+ for (const auto& dir : temporary_directories_) {
+ tensorflow::int64 undeleted_dirs, undeleted_files;
+ TF_CHECK_OK(
+ env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files));
+ }
+ temporary_directories_.clear();
+ }
+
+ // Unzip `zip` file into a new temporary directory `out_dir`.
+ tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) {
+ string dir;
+ TF_CHECK_OK(MakeTemporaryDirectory(&dir));
+ tensorflow::SubProcess proc;
+ std::string unzip_binary =
+ "/usr/bin/unzip";
+ proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()});
+ proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
+ proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+ if (!proc.Start())
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "unzip couldn't start");
+ string out, err;
+ int status = proc.Communicate(nullptr, &out, &err);
+ if (WEXITSTATUS(status) == 0) {
+ *out_dir = dir;
+ return tensorflow::Status::OK();
+ } else {
+ return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed");
+ }
+ }
+
+ private:
+ // Make a temporary directory and return its name in `temporary`.
+ tensorflow::Status MakeTemporaryDirectory(string* temporary) {
+ if (env->LocalTempFilename(temporary)) {
+ TF_CHECK_OK(env->CreateDir(*temporary));
+ temporary_directories_.push_back(*temporary);
+ return tensorflow::Status::OK();
+ }
+ return tensorflow::Status(tensorflow::error::UNKNOWN,
+ "make temporary directory failed");
+ }
+
+ std::vector<string> temporary_directories_;
+};
+
+// Return the singleton zip_environment.
+ZipEnvironment* zip_environment() {
+ static ZipEnvironment* env = new ZipEnvironment;
+ return env;
+}
+
+// Read the manifest.txt out of the unarchived zip file. Specifically
+// `original_file` is the original zip file for error messages. `dir` is
+// the temporary directory where the zip file has been unarchived and
+// `test_paths` is the list of test prefixes that were in the manifest.
+// Note, it is an error for a manifest to contain no tests.
+tensorflow::Status ReadManifest(const std::string& original_file,
+ const std::string& dir,
+ std::vector<std::string>* test_paths) {
+ // Read the newline delimited list of entries in the manifest.
+ std::ifstream manifest_fp(dir + "/manifest.txt");
+ std::string manifest((std::istreambuf_iterator<char>(manifest_fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+ int added = 0;
+ while (true) {
+ size_t end_pos = manifest.find("\n", pos);
+ if (end_pos == std::string::npos) break;
+ std::string filename = manifest.substr(pos, end_pos - pos);
+ test_paths->push_back(dir + "/" + filename);
+ pos = end_pos + 1;
+ added += 1;
+ }
+ if (!added) {
+ std::string message = "Test had no examples: " + original_file;
+ return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str());
+ }
+ return tensorflow::Status::OK();
+}
+
+// Get a list of tests from a zip file `zip_file_name`.
+std::vector<std::string> UnarchiveZipAndFindTestNames(
+ const std::string& zip_file_name) {
+ std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() +
+ "/contrib/lite/testing/optest/" + zip_file_name;
+ std::string decompress_tmp_dir;
+ TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir));
+ std::vector<std::string> stuff;
+ TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff));
+ return stuff;
+}
+
+class OpsTest : public ::testing::TestWithParam<std::string> {};
+
+TEST_P(OpsTest, RunStuff) {
+ std::string test_path = GetParam();
+ std::string tflite_file = test_path + ".bin";
+ std::string tflite_examples = test_path + ".inputs";
+ auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str());
+ std::unique_ptr<tflite::Interpreter> interpreter;
+
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+ ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter),
+ kTfLiteOk);
+
+ std::vector<tflite::testing::Example> examples;
+ ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples),
+ kTfLiteOk);
+
+ string bug_number;
+ for (const auto& p : kBrokenTests) {
+ if (RE2::PartialMatch(test_path, p.first)) {
+ bug_number = p.second;
+ }
+ }
+
+ for (const auto& example : examples) {
+ ASSERT_EQ(interpreter->inputs().size(), example.inputs.size());
+ auto result = [&]() {
+ TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example));
+ TF_LITE_ENSURE_STATUS(interpreter->Invoke());
+ TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example));
+ return kTfLiteOk;
+ }();
+
+ if (bug_number.empty()) {
+ ASSERT_EQ(result, kTfLiteOk);
+ } else {
+ if (FLAGS_ignore_known_bugs) {
+ ASSERT_EQ(result, kTfLiteError)
+ << "Not failing as expected dut to http://b/" << bug_number;
+ } else {
+ ASSERT_EQ(result, kTfLiteOk)
+ << "Possibly due to http://b/" << bug_number;
+ }
+ }
+ }
+}
+
+// Instantiate a test. This assumes `zip_base`.zip is a declared data file
+// of this test.
+#define INSTANTIATE_TESTS(zip_base) \
+ INSTANTIATE_TEST_CASE_P( \
+ zip_base, OpsTest, \
+ ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")));
+
+INSTANTIATE_TESTS(add)
+INSTANTIATE_TESTS(avg_pool)
+INSTANTIATE_TESTS(concat)
+INSTANTIATE_TESTS(constant)
+INSTANTIATE_TESTS(control_dep)
+INSTANTIATE_TESTS(conv)
+INSTANTIATE_TESTS(depthwiseconv)
+INSTANTIATE_TESTS(fully_connected)
+INSTANTIATE_TESTS(fused_batch_norm)
+INSTANTIATE_TESTS(global_batch_norm)
+INSTANTIATE_TESTS(l2norm)
+INSTANTIATE_TESTS(l2_pool)
+INSTANTIATE_TESTS(local_response_norm)
+INSTANTIATE_TESTS(max_pool)
+INSTANTIATE_TESTS(mul)
+INSTANTIATE_TESTS(relu)
+INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(relu6)
+INSTANTIATE_TESTS(reshape)
+INSTANTIATE_TESTS(resize_bilinear)
+INSTANTIATE_TESTS(sigmoid)
+INSTANTIATE_TESTS(softmax)
+INSTANTIATE_TESTS(space_to_depth)
+
+} // namespace testing
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment());
+
+ std::vector<tensorflow::Flag> flags = {tensorflow::Flag(
+ "ignore_known_bugs", &FLAGS_ignore_known_bugs,
+ "If a particular model is affected by a known bug, the "
+ "corresponding test should expect the outputs to not match.")};
+ bool success = tensorflow::Flags::Parse(&argc, argv, flags);
+ if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
+ fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
+ return 1;
+ }
+
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc
new file mode 100644
index 0000000000..03fae4bb86
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <stack>
+
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+namespace tflite {
+namespace testing {
+
+// A token processor that builds messages and forward calls to the current
+// message object. Place a new message at the top of the stack when it start
+// and remove it when it is finished.
+class MessageStack : public TokenProcessor {
+ public:
+ // Start a new MessageStack with the given first_node, which will be used to
+ // process freestanding fields and submessages.
+ explicit MessageStack(Message* first_node) {
+ nodes_.push(first_node);
+ valid_ = true;
+ }
+
+ void ConsumeToken(std::string* token) override {
+ if (!valid_) return;
+ Message* current_node = nodes_.top();
+ if (*token == "{") {
+ // This is the beginning of a new message, names after the previous token.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ nodes_.push(current_node ? current_node->AddChild(previous_token_)
+ : nullptr);
+ previous_token_.clear();
+ } else if (*token == "}") {
+ // A message is being completed. There should be no previous token. Note
+ // that the top-level message never closes, so we should always have at
+ // least one entry in the stack.
+ if (nodes_.size() == 1 || !previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ if (current_node) {
+ current_node->Finish();
+ }
+ nodes_.pop();
+ } else if (*token == ":") {
+ // We reached the end of the 'key' portion of a field. Store the token
+ // until we have the 'value' portion.
+ if (previous_token_.empty()) {
+ valid_ = false;
+ return;
+ }
+ } else {
+ if (previous_token_.empty()) {
+ previous_token_.swap(*token);
+ } else {
+ // This is the 'value' portion of a field. The previous token is the
+ // 'key'.
+ if (current_node) {
+ current_node->SetField(previous_token_, *token);
+ }
+ previous_token_.clear();
+ }
+ }
+ }
+
+ bool valid() const { return valid_; }
+
+ private:
+ std::stack<Message*> nodes_;
+ std::string previous_token_;
+ bool valid_;
+};
+
+bool Message::Read(std::istream* input, Message* message) {
+ MessageStack stack(message);
+ Tokenize(input, &stack);
+ return stack.valid();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h
new file mode 100644
index 0000000000..78ef7e2cbe
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message.h
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace testing {
+
+// A Message is a textual protobuf-like structure that looks like:
+// tag {
+// f : "values"
+// child {
+// a : 1
+// }
+// }
+// This class provides the framework for processing message but does not
+// associate any particular behavior to fields and submessage. In order
+// to properly parse a stream this class must be derived.
+class Message {
+ public:
+ // Reads a stream, tokenizes it and create a new message under the given
+ // top-level message. Returns true if the parsing succeeded.
+ static bool Read(std::istream* input, Message* message);
+
+ Message() {}
+ virtual ~Message() {}
+
+ // Called when a new field is found. For example, when:
+ // f : "values"
+ // is found, it triggers:
+ // SetField("f", "values");
+ virtual void SetField(const std::string& name, const std::string& value) {}
+
+ // Called when a submessage is started. For example, when:
+ // child {
+ // is found, it triggers
+ // AddChild("child");
+ // If nullptr is returned, the contents of the submessage will be ignored.
+ // Otherwise, the returned Message will be used to handle new fields and new
+ // submessages. The caller should not take ownership of the returned pointer.
+ virtual Message* AddChild(const std::string& name) { return nullptr; }
+
+ // Called when a submessage is completed, that is, whenever a '}' is found.
+ virtual void Finish() {}
+
+ protected:
+ // Takes ownership of the given pointer. Subclasses can use this method if
+ // they don't want to implement their own ownership semantics.
+ Message* Store(Message* n) {
+ children_.emplace_back(n);
+ return n;
+ }
+
+ // Returns a list of all owned submessages.
+ const std::vector<std::unique_ptr<Message>>& Children() const {
+ return children_;
+ }
+
+ private:
+ std::vector<std::unique_ptr<Message>> children_;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc
new file mode 100644
index 0000000000..fb6a49bd6f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/message_test.cc
@@ -0,0 +1,121 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/message.h"
+
+#include <map>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// A hierarchical, key-value store.
+class TestMessage : public Message {
+ public:
+ TestMessage() {}
+ explicit TestMessage(const std::string& text_to_parse) {
+ std::stringstream ss(text_to_parse);
+ finished_ = Message::Read(&ss, this);
+ }
+ void SetField(const std::string& name, const std::string& value) override {
+ fields_[name] = value;
+ }
+ Message* AddChild(const std::string& name) override {
+ TestMessage* m = new TestMessage;
+ m->name_ = name;
+ return Store(m);
+ }
+ void Finish() override { finished_ = true; }
+
+ int NumChildren() const { return Children().size(); }
+
+ const TestMessage* GetChild(int i) const {
+ return dynamic_cast<TestMessage*>(Children()[i].get());
+ }
+
+ int NumFields() const { return fields_.size(); }
+ const std::string& GetField(const std::string& key) const {
+ return fields_.at(key);
+ }
+
+ const std::string& name() const { return name_; }
+ bool finished() const { return finished_; }
+
+ protected:
+ std::string name_;
+ std::map<std::string, std::string> fields_;
+ bool finished_ = false;
+};
+
+TEST(MessageTest, Simple) {
+ TestMessage message("x{a:1 b:2} y{} z{c:3} d:4");
+ ASSERT_TRUE(message.finished());
+
+ ASSERT_EQ(message.NumFields(), 1);
+ EXPECT_EQ(message.GetField("d"), "4");
+
+ ASSERT_EQ(message.NumChildren(), 3);
+
+ auto* x = message.GetChild(0);
+ EXPECT_EQ(x->name(), "x");
+ ASSERT_EQ(x->NumFields(), 2);
+ EXPECT_EQ(x->GetField("a"), "1");
+ EXPECT_EQ(x->GetField("b"), "2");
+
+ auto* y = message.GetChild(1);
+ EXPECT_EQ(y->name(), "y");
+ ASSERT_EQ(y->NumFields(), 0);
+
+ auto* z = message.GetChild(2);
+ EXPECT_EQ(z->name(), "z");
+ ASSERT_EQ(z->NumFields(), 1);
+ EXPECT_EQ(z->GetField("c"), "3");
+}
+
+TEST(MessageTest, Unnamed) {
+ TestMessage message("x{c:3} {} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, TooManyBraces) {
+ TestMessage message("x{c:3} } y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 1);
+}
+
+TEST(MessageTest, LeftoverToken) {
+ TestMessage message("x{c:3} z{test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingKey) {
+ TestMessage message("x{c:3} z{:test} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+TEST(MessageTest, MissingValue) {
+ TestMessage message("x{c:3} z{test:} y{d:4}");
+ ASSERT_FALSE(message.finished());
+ EXPECT_EQ(message.NumChildren(), 2);
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc
new file mode 100644
index 0000000000..74f6cfc3de
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/nnapi_example.cc
@@ -0,0 +1,114 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// NOTE: this is an example driver that converts a tflite model to TensorFlow.
+// This is an example that will be integrated more tightly into tflite in
+// the future.
+//
+// Usage: bazel run -c opt \
+// tensorflow/contrib/lite/nnapi:nnapi_example -- <filename>
+//
+#include <cstdarg>
+#include <cstdio>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+// TODO(aselle): FATAL leaves resources hanging.
+void FATAL(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ vfprintf(stderr, format, args);
+ va_end(args);
+ fflush(stderr);
+ exit(1);
+}
+
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure."); \
+ }
+
+void Interpret(const char* filename, const char* examples_filename,
+ bool use_nnapi) {
+ // TODO(aselle): Resize of input image should go here
+ // ...
+ // For now I am allocating all tensors. This means I am fixed size.
+ // So I am not using the variable size ability yet.
+ fprintf(stderr, "example file %s\n", examples_filename);
+ std::vector<tflite::testing::Example> examples;
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::ParseExamples(examples_filename, &examples));
+
+ for (const tflite::testing::Example& example : examples) {
+ auto model = tflite::FlatBufferModel::BuildFromFile(filename);
+ if (!model) FATAL("Cannot read file %s\n", filename);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtins;
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::InterpreterBuilder(*model, builtins)(&interpreter));
+
+ printf("Use nnapi is set to: %d\n", use_nnapi);
+ interpreter->UseNNAPI(use_nnapi);
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::FeedExample(interpreter.get(), example));
+
+ {
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ *p = 0;
+ }
+ }
+ }
+ interpreter->Invoke();
+
+ CHECK_TFLITE_SUCCESS(
+ tflite::testing::CheckOutputs(interpreter.get(), example));
+
+ printf("Result:\n");
+ TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]);
+ if (float* data =
+ interpreter->typed_tensor<float>(interpreter->outputs()[0])) {
+ size_t num = tensor->bytes / sizeof(float);
+ for (float* p = data; p < data + num; p++) {
+ printf(" %f", *p);
+ }
+ }
+ }
+}
+
+int main(int argc, char* argv[]) {
+ bool use_nnapi = true;
+ if (argc == 4) {
+ use_nnapi = strcmp(argv[3], "1") == 0 ? true : false;
+ }
+ if (argc < 3) {
+ fprintf(stderr,
+ "Compiled " __DATE__ __TIME__
+ "\n"
+ "Usage!!!: %s <tflite model> <examples to test> "
+ "{ use nn api i.e. 0,1}\n",
+ argv[0]);
+ return 1;
+ }
+ Interpret(argv[1], argv[2], use_nnapi);
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc
new file mode 100644
index 0000000000..2b67052cad
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.cc
@@ -0,0 +1,335 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Parses tflite example input data.
+// Format is ASCII
+// TODO(aselle): Switch to protobuf, but the android team requested a simple
+// ASCII file.
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <streambuf>
+
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/testing/message.h"
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Fatal error if parse error occurs
+#define PARSE_CHECK_EQ(filename, current_line, x, y) \
+ if ((x) != (y)) { \
+ fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \
+ __FILE__, __LINE__, filename, current_line + 1, #x, #y); \
+ return kTfLiteError; \
+ }
+
+// Breakup a "," delimited line into a std::vector<std::string>.
+// This is extremely inefficient, and just used for testing code.
+// TODO(aselle): replace with absl when we use it.
+std::vector<std::string> ParseLine(const std::string& line) {
+ size_t pos = 0;
+ std::vector<std::string> elements;
+ while (true) {
+ size_t end = line.find(',', pos);
+ if (end == std::string::npos) {
+ elements.push_back(line.substr(pos));
+ break;
+ } else {
+ elements.push_back(line.substr(pos, end - pos));
+ }
+ pos = end + 1;
+ }
+ return elements;
+}
+
+} // namespace
+
+// Given a `filename`, produce a vector of Examples corresopnding
+// to test cases that can be applied to a tflite model.
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples) {
+ std::ifstream fp(filename);
+ if (!fp.good()) {
+ fprintf(stderr, "Could not read '%s'\n", filename);
+ return kTfLiteError;
+ }
+ std::string str((std::istreambuf_iterator<char>(fp)),
+ std::istreambuf_iterator<char>());
+ size_t pos = 0;
+
+ // \n and , delimit parse a file.
+ std::vector<std::vector<std::string>> csv;
+ while (true) {
+ size_t end = str.find('\n', pos);
+
+ if (end == std::string::npos) {
+ csv.emplace_back(ParseLine(str.substr(pos)));
+ break;
+ }
+ csv.emplace_back(ParseLine(str.substr(pos, end - pos)));
+ pos = end + 1;
+ }
+
+ int current_line = 0;
+ PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases");
+ int example_count = std::stoi(csv[0][1]);
+ current_line++;
+
+ auto parse_tensor = [&filename, &current_line,
+ &csv](FloatTensor* tensor_ptr) {
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
+ current_line++;
+ // parse shape
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
+ size_t elements = 1;
+ FloatTensor& tensor = *tensor_ptr;
+
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ const auto& shape_part_to_parse = csv[current_line][i];
+ if (shape_part_to_parse.empty()) {
+ // Case of a 0-dimensional shape
+ break;
+ }
+ int shape_part = std::stoi(shape_part_to_parse);
+ elements *= shape_part;
+ tensor.shape.push_back(shape_part);
+ }
+ current_line++;
+ // parse data
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1,
+ elements);
+ for (size_t i = 1; i < csv[current_line].size(); i++) {
+ tensor.flat_data.push_back(std::stof(csv[current_line][i]));
+ }
+ current_line++;
+
+ return kTfLiteOk;
+ };
+
+ for (int example_idx = 0; example_idx < example_count; example_idx++) {
+ Example example;
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
+ int inputs = std::stoi(csv[current_line][1]);
+ current_line++;
+ // parse dtype
+ for (int input_index = 0; input_index < inputs; input_index++) {
+ example.inputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
+ }
+
+ PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
+ int outputs = std::stoi(csv[current_line][1]);
+ current_line++;
+ for (int input_index = 0; input_index < outputs; input_index++) {
+ example.outputs.push_back(FloatTensor());
+ TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
+ }
+ examples->emplace_back(example);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
+ const Example& example) {
+ // Resize inputs to match example & allocate.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+
+ TF_LITE_ENSURE_STATUS(
+ interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
+ }
+ TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
+ // Copy data into tensors.
+ for (size_t i = 0; i < interpreter->inputs().size(); i++) {
+ int input_index = interpreter->inputs()[i];
+ if (float* data = interpreter->typed_tensor<float>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else if (int32_t* data =
+ interpreter->typed_tensor<int32_t>(input_index)) {
+ for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
+ data[idx] = example.inputs[i].flat_data[idx];
+ }
+ } else {
+ fprintf(stderr, "input[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
+ const Example& example) {
+ constexpr double kRelativeThreshold = 1e-2f;
+ constexpr double kAbsoluteThreshold = 1e-4f;
+
+ ErrorReporter* context = DefaultErrorReporter();
+ int model_outputs = interpreter->outputs().size();
+ TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
+ for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+ int output_index = interpreter->outputs()[i];
+ if (const float* data = interpreter->typed_tensor<float>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ float computed = data[idx];
+ float reference = example.outputs[0].flat_data[idx];
+ float diff = std::abs(computed - reference);
+ bool error_is_large = false;
+ // For very small numbers, try absolute error, otherwise go with
+ // relative.
+ if (std::abs(reference) < kRelativeThreshold) {
+ error_is_large = (diff > kAbsoluteThreshold);
+ } else {
+ error_is_large = (diff > kRelativeThreshold * std::abs(reference));
+ }
+ if (error_is_large) {
+ fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
+ i, idx, data[idx], reference);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else if (const int32_t* data =
+ interpreter->typed_tensor<int32_t>(output_index)) {
+ for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
+ int32_t computed = data[idx];
+ int32_t reference = example.outputs[0].flat_data[idx];
+ if (std::abs(computed - reference) > 0) {
+ fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n",
+ i, idx, data[idx], example.outputs[0].flat_data[idx]);
+ return kTfLiteError;
+ }
+ }
+ fprintf(stderr, "\n");
+ } else {
+ fprintf(stderr, "output[%zu] was not float or int data\n", i);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+// Process an 'invoke' message, triggering execution of the test runner, as
+// well as verification of outputs. An 'invoke' message looks like:
+// invoke {
+// id: xyz
+// input: 1,2,1,1,1,2,3,4
+// ouput: 4,5,6
+// }
+class Invoke : public Message {
+ public:
+ explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ expected_outputs_ = test_runner->GetOutputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "id") {
+ test_runner_->SetInvocationId(value);
+ } else if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs");
+ }
+ test_runner_->SetInput(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ } else if (name == "output") {
+ if (expected_outputs_.empty()) {
+ return test_runner_->Invalidate("Too many outputs");
+ }
+ test_runner_->SetExpectation(*expected_outputs_.begin(), value);
+ expected_outputs_.erase(expected_outputs_.begin());
+ }
+ }
+ void Finish() override {
+ test_runner_->Invoke();
+ test_runner_->CheckResults();
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ std::vector<int> expected_outputs_;
+
+ TestRunner* test_runner_;
+};
+
+// Process an 'reshape' message, triggering resizing of the input tensors via
+// the test runner. A 'reshape' message looks like:
+// reshape {
+// input: 1,2,1,1,1,2,3,4
+// }
+class Reshape : public Message {
+ public:
+ explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) {
+ expected_inputs_ = test_runner->GetInputs();
+ }
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "input") {
+ if (expected_inputs_.empty()) {
+ return test_runner_->Invalidate("Too many inputs to reshape");
+ }
+ test_runner_->ReshapeTensor(*expected_inputs_.begin(), value);
+ expected_inputs_.erase(expected_inputs_.begin());
+ }
+ }
+
+ private:
+ std::vector<int> expected_inputs_;
+ TestRunner* test_runner_;
+};
+
+// This is the top-level message in a test file.
+class TestData : public Message {
+ public:
+ explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {}
+
+ void SetField(const std::string& name, const std::string& value) override {
+ if (name == "load_model") {
+ test_runner_->LoadModel(value);
+ } else if (name == "init_state") {
+ test_runner_->AllocateTensors();
+ for (int id : Split<int>(value, ",")) {
+ test_runner_->ResetTensor(id);
+ }
+ }
+ }
+ Message* AddChild(const std::string& s) override {
+ if (s == "invoke") {
+ test_runner_->AllocateTensors();
+ return Store(new Invoke(test_runner_));
+ } else if (s == "reshape") {
+ return Store(new Reshape(test_runner_));
+ }
+ return nullptr;
+ }
+
+ private:
+ TestRunner* test_runner_;
+};
+
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) {
+ TestData test_data(test_runner);
+ Message::Read(input, &test_data);
+ return test_runner->IsValid() && test_runner->GetOverallSuccess();
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h
new file mode 100644
index 0000000000..90839fe245
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/parse_testdata.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+
+#include <vector>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// Shape and data for a float tensor
+struct FloatTensor {
+ std::vector<int> shape;
+ std::vector<float> flat_data;
+};
+
+// A prescribed input, output example
+struct Example {
+ std::vector<FloatTensor> inputs;
+ std::vector<FloatTensor> outputs;
+};
+
+// Parses an example input and output file (used for unit tests)
+TfLiteStatus ParseExamples(const char* filename,
+ std::vector<Example>* examples);
+
+// Inputs Tensors into a TensorFlow lite interpreter. Note, this will run
+// interpreter.AllocateTensors();
+TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&);
+
+// Check outputs against (already) evaluated result.
+TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&);
+
+// Parses a test description and feeds the given test runner with data.
+// The input format is similar to an ASCII proto:
+// // Loads model 'add.bin' from the TestRunner's model directory.
+// load_model: "add.bin"
+// // Changes the shape of inputs, provided in the same order they appear
+// // in the model.
+// reshape {
+// input: "1,224,224,3"
+// input: "1,3,4,1"
+// }
+// // Fills the given persistent tensors with zeros.
+// init_state: 0,1,2,3
+// // Invokes the interpreter with the given input and checks that it
+// // produces the expected output. Inputs and outputs should be specified in
+// // the order they appear in the model.
+// invoke {
+// input: "1,2,3,4,56"
+// input: "0.1,0.2,0.3,4.3,56.4"
+// output: "12,3,4,545,3"
+// output: "0.01,0.02"
+// }
+bool ParseAndRunTests(std::istream* input, TestRunner* test_runner);
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc
new file mode 100644
index 0000000000..5836f4ff04
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter) {
+ std::vector<std::pair<size_t, size_t>> fields;
+ if (delimiter.length() == 0) {
+ fields.emplace_back(0, s.length());
+ return fields;
+ }
+ size_t pos = 0;
+ size_t start = 0;
+ while ((pos = s.find(delimiter, start)) != string::npos) {
+ if (pos != start) {
+ fields.emplace_back(start, pos);
+ }
+ start = pos + delimiter.length();
+ }
+ if (start != s.length()) {
+ fields.emplace_back(start, s.length());
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h
new file mode 100644
index 0000000000..24071442e8
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+
+#include <cstdlib>
+#include <string>
+#include <utility>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// Splits a string based on the given delimiter string. Each pair in the
+// returned vector has the start and past-the-end positions for each of the
+// parts of the original string. Empty fields are not represented in the
+// output.
+std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s,
+ const string& delimiter);
+
+// Splits the given string and converts each part to the given T.
+template <typename T>
+std::vector<T> Split(const string& s, const string& delimiter);
+
+template <>
+inline std::vector<string> Split(const string& s, const string& delimiter) {
+ std::vector<string> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(s.substr(p.first, p.second - p.first));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<int> Split(const string& s, const string& delimiter) {
+ std::vector<int> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<float> Split(const string& s, const string& delimiter) {
+ std::vector<float> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtod(s.data() + p.first, nullptr));
+ }
+ return fields;
+}
+
+template <>
+inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
+ std::vector<uint8_t> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(strtol(s.data() + p.first, nullptr, 10));
+ }
+ return fields;
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc
new file mode 100644
index 0000000000..3d1e25d9c7
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/split_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/split.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pair;
+
+TEST(SplitTest, SplitToPos) {
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"),
+ ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"),
+ ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19)));
+ EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4)));
+ EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("", ":"), ElementsAre());
+ EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5)));
+ EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre());
+}
+
+TEST(SplitTest, SplitString) {
+ EXPECT_THAT(Split<string>("A;B;C", ";"), ElementsAre("A", "B", "C"));
+}
+
+TEST(SplitTest, SplitFloat) {
+ EXPECT_THAT(Split<float>("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5));
+}
+
+TEST(SplitTest, SplitInt) {
+ EXPECT_THAT(Split<int>("1,-1,258", ","), ElementsAre(1, -1, 258));
+}
+
+TEST(SplitTest, SplitUint8) {
+ EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h
new file mode 100644
index 0000000000..04ee4d9f7d
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner.h
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+// This is the base class for processing test data. Each one of the virtual
+// methods must be implemented to forward the data to the appropriate executor
+// (e.g. TF Lite's interpreter, or the NNAPI).
+class TestRunner {
+ public:
+ TestRunner() {}
+ virtual ~TestRunner() {}
+
+ // Load the given model, as a path relative to SetModelBaseDir().
+ virtual void LoadModel(const string& bin_file_path) = 0;
+
+ // Return the list of input tensors in the loaded model.
+ virtual const std::vector<int>& GetInputs() = 0;
+
+ // Return the list of output tensors in the loaded model.
+ virtual const std::vector<int>& GetOutputs() = 0;
+
+ // Prepare for a run by resize the given tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void ReshapeTensor(int id, const string& csv_values) = 0;
+
+ // Reserve memory for all tensors.
+ virtual void AllocateTensors() = 0;
+
+ // Set the given tensor to some initial state, usually zero. This is
+ // used to reset persistent buffers in a model.
+ virtual void ResetTensor(int id) = 0;
+
+ // Define the contents of the given input tensor. The given 'id' is
+ // guaranteed to be one of the ids returned by GetInputs().
+ virtual void SetInput(int id, const string& csv_values) = 0;
+
+ // Define what should be expected for an output tensor after Invoke() runs.
+ // The given 'id' is guaranteed to be one of the ids returned by
+ // GetOutputs().
+ virtual void SetExpectation(int id, const string& csv_values) = 0;
+
+ // Run the model.
+ virtual void Invoke() = 0;
+
+ // Verify that the contents of all ouputs conform to the existing
+ // expectations. Return true if there are no expectations or they are all
+ // satisfied.
+ virtual bool CheckResults() = 0;
+
+ // Set the base path for loading models.
+ void SetModelBaseDir(const string& path) {
+ model_base_dir_ = path;
+ if (path[path.length() - 1] != '/') {
+ model_base_dir_ += "/";
+ }
+ }
+
+ // Return the full path of a model.
+ string GetFullPath(const string& path) { return model_base_dir_ + path; }
+
+ // Give an id to the next invocation to make error reporting more meaningful.
+ void SetInvocationId(const string& id) { invocation_id_ = id; }
+ const string& GetInvocationId() const { return invocation_id_; }
+
+ // Invalidate the test runner, preventing it from executing any further.
+ void Invalidate(const string& error_message) {
+ error_message_ = error_message;
+ }
+ bool IsValid() const { return error_message_.empty(); }
+ const string& GetErrorMessage() const { return error_message_; }
+
+ // Handle the overall success of this test runner. This will be true if all
+ // invocations were successful.
+ void SetOverallSuccess(bool value) { overall_success_ = value; }
+ bool GetOverallSuccess() const { return overall_success_; }
+
+ protected:
+ // A helper to check of the given number of values is consistent with the
+ // number of bytes in a tensor of type T. When incompatibles sizes are found,
+ // the test runner is invalidated and false is returned.
+ template <typename T>
+ bool CheckSizes(size_t tensor_bytes, size_t num_values) {
+ size_t num_tensor_elements = tensor_bytes / sizeof(T);
+ if (num_tensor_elements != num_values) {
+ Invalidate("Expected '" + std::to_string(num_tensor_elements) +
+ "' elements for a tensor, but only got '" +
+ std::to_string(num_values) + "'");
+ return false;
+ }
+ return true;
+ }
+
+ private:
+ string model_base_dir_;
+ string invocation_id_;
+ bool overall_success_ = true;
+
+ string error_message_;
+};
+
+} // namespace testing
+} // namespace tflite
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc
new file mode 100644
index 0000000000..f712a5347a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/test_runner_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+class ConcreteTestRunner : public TestRunner {
+ public:
+ void LoadModel(const string& bin_file_path) override {}
+ const std::vector<int>& GetInputs() override { return ids_; }
+ const std::vector<int>& GetOutputs() override { return ids_; }
+ void ReshapeTensor(int id, const string& csv_values) override {}
+ void AllocateTensors() override {}
+ void ResetTensor(int id) override {}
+ void SetInput(int id, const string& csv_values) override {}
+ void SetExpectation(int id, const string& csv_values) override {}
+ void Invoke() override {}
+ bool CheckResults() override { return true; }
+ bool CheckFloatSizes(size_t bytes, size_t values) {
+ return CheckSizes<float>(bytes, values);
+ }
+
+ private:
+ std::vector<int> ids_;
+};
+
+TEST(TestRunner, ModelPath) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin");
+ runner.SetModelBaseDir("/tmp");
+ EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin");
+}
+
+TEST(TestRunner, InvocationId) {
+ ConcreteTestRunner runner;
+ EXPECT_EQ(runner.GetInvocationId(), "");
+ runner.SetInvocationId("X");
+ EXPECT_EQ(runner.GetInvocationId(), "X");
+}
+
+TEST(TestRunner, Invalidation) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "");
+ runner.Invalidate("Some Error");
+ EXPECT_FALSE(runner.IsValid());
+ EXPECT_EQ(runner.GetErrorMessage(), "Some Error");
+}
+
+TEST(TestRunner, OverallSuccess) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.GetOverallSuccess());
+ runner.SetOverallSuccess(false);
+ EXPECT_FALSE(runner.GetOverallSuccess());
+}
+
+TEST(TestRunner, CheckSizes) {
+ ConcreteTestRunner runner;
+ EXPECT_TRUE(runner.CheckFloatSizes(16, 4));
+ EXPECT_FALSE(runner.CheckFloatSizes(16, 2));
+ EXPECT_EQ(runner.GetErrorMessage(),
+ "Expected '4' elements for a tensor, but only got '2'");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
new file mode 100644
index 0000000000..cf9df2ec26
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+
+namespace tflite {
+namespace testing {
+
+namespace {
+
+// Returns the value in the given position in a tensor.
+template <typename T>
+T Value(const TfLitePtrUnion& data, int index);
+template <>
+float Value(const TfLitePtrUnion& data, int index) {
+ return data.f[index];
+}
+template <>
+uint8_t Value(const TfLitePtrUnion& data, int index) {
+ return data.uint8[index];
+}
+
+template <typename T>
+void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
+ T* input_ptr = reinterpret_cast<T*>(data->raw);
+ for (const T& v : values) {
+ *input_ptr = v;
+ ++input_ptr;
+ }
+}
+
+} // namespace
+
+class TfLiteDriver::Expectation {
+ public:
+ Expectation() { data_.raw = nullptr; }
+ ~Expectation() { delete[] data_.raw; }
+ template <typename T>
+ void SetData(const string& csv_values) {
+ const auto& values = testing::Split<T>(csv_values, ",");
+ data_.raw = new char[values.size() * sizeof(T)];
+ SetTensorData(values, &data_);
+ }
+
+ bool Check(bool verbose, const TfLiteTensor& tensor) {
+ switch (tensor.type) {
+ case kTfLiteFloat32:
+ return TypedCheck<float>(verbose, tensor);
+ case kTfLiteUInt8:
+ return TypedCheck<uint8_t>(verbose, tensor);
+ default:
+ return false;
+ }
+ }
+
+ private:
+ template <typename T>
+ bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
+ int tensor_size = tensor.bytes / sizeof(T);
+
+ bool good_output = true;
+ for (int i = 0; i < tensor_size; ++i) {
+ if (std::abs(Value<T>(data_, i) - Value<T>(tensor.data, i)) > 1e-5) {
+ good_output = false;
+ if (verbose) {
+ std::cerr << " index " << i << ": " << Value<T>(data_, i)
+ << " != " << Value<T>(tensor.data, i) << std::endl;
+ }
+ }
+ }
+ return good_output;
+ }
+
+ TfLitePtrUnion data_;
+};
+
+TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
+TfLiteDriver::~TfLiteDriver() {}
+
+void TfLiteDriver::AllocateTensors() {
+ if (must_allocate_tensors_) {
+ if (interpreter_->AllocateTensors() != kTfLiteOk) {
+ std::cerr << "Failed to allocate tensors" << std::endl;
+ abort();
+ }
+ must_allocate_tensors_ = false;
+ }
+}
+
+void TfLiteDriver::LoadModel(const string& bin_file_path) {
+ if (!IsValid()) return;
+ std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
+
+ model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
+ if (!model_) {
+ Invalidate("Failed to mmap model " + bin_file_path);
+ return;
+ }
+ ops::builtin::BuiltinOpResolver builtins;
+ InterpreterBuilder(*model_, builtins)(&interpreter_);
+ if (!interpreter_) {
+ Invalidate("Failed build interpreter");
+ return;
+ }
+
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::ResetTensor(int id) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ memset(tensor->data.raw, 0, tensor->bytes);
+}
+
+void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ if (interpreter_->ResizeInputTensor(
+ id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
+ Invalidate("Failed to resize input tensor " + std::to_string(id));
+ return;
+ }
+ must_allocate_tensors_ = true;
+}
+
+void TfLiteDriver::SetInput(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ switch (tensor->type) {
+ case kTfLiteFloat32: {
+ const auto& values = testing::Split<float>(csv_values, ",");
+ if (!CheckSizes<float>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ case kTfLiteUInt8: {
+ const auto& values = testing::Split<uint8_t>(csv_values, ",");
+ if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
+ if (!IsValid()) return;
+ auto* tensor = interpreter_->tensor(id);
+ expected_output_[id].reset(new Expectation);
+ switch (tensor->type) {
+ case kTfLiteFloat32:
+ expected_output_[id]->SetData<float>(csv_values);
+ break;
+ case kTfLiteUInt8:
+ expected_output_[id]->SetData<uint8_t>(csv_values);
+ break;
+ default:
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfLiteDriver::Invoke() {
+ if (!IsValid()) return;
+ if (interpreter_->Invoke() != kTfLiteOk) {
+ Invalidate("Failed to invoke interpreter");
+ }
+}
+
+bool TfLiteDriver::CheckResults() {
+ if (!IsValid()) return false;
+ bool success = true;
+ for (const auto& p : expected_output_) {
+ int id = p.first;
+ auto* tensor = interpreter_->tensor(id);
+ if (!p.second->Check(/*verbose=*/false, *tensor)) {
+ // Do not invalidate anything here. Instead, simply output the
+ // differences and return false. Invalidating would prevent all
+ // subsequent invocations from running..
+ std::cerr << "There were errors in invocation '" << GetInvocationId()
+ << "', output tensor '" << id << "':" << std::endl;
+ p.second->Check(/*verbose=*/true, *tensor);
+ success = false;
+ SetOverallSuccess(false);
+ }
+ }
+ expected_output_.clear();
+ return success;
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
new file mode 100644
index 0000000000..4440d4285e
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+
+#include <map>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+
+namespace tflite {
+namespace testing {
+
+// A test runner that feeds inputs into TF Lite and verifies its outputs.
+class TfLiteDriver : public TestRunner {
+ public:
+ explicit TfLiteDriver(bool use_nnapi);
+ ~TfLiteDriver() override;
+
+ void LoadModel(const string& bin_file_path) override;
+ const std::vector<int>& GetInputs() override {
+ return interpreter_->inputs();
+ }
+ const std::vector<int>& GetOutputs() override {
+ return interpreter_->outputs();
+ }
+ void ReshapeTensor(int id, const string& csv_values) override;
+ void AllocateTensors() override;
+ void ResetTensor(int id) override;
+ void SetInput(int id, const string& csv_values) override;
+ void SetExpectation(int id, const string& csv_values) override;
+ void Invoke() override;
+ bool CheckResults() override;
+
+ private:
+ class Expectation;
+
+ bool use_nnapi_ = false;
+ std::unique_ptr<FlatBufferModel> model_;
+ std::unique_ptr<Interpreter> interpreter_;
+ std::map<int, std::unique_ptr<Expectation>> expected_output_;
+ bool must_allocate_tensors_ = true;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
new file mode 100644
index 0000000000..37010c468f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TfliteDriverTest, SimpleTest) {
+ std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false));
+
+ runner->SetModelBaseDir("tensorflow/contrib/lite");
+ runner->LoadModel("testdata/multi_add.bin");
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3));
+ ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6));
+
+ for (int i : {0, 1, 2, 3}) {
+ runner->ReshapeTensor(i, "1,2,2,1");
+ }
+ ASSERT_TRUE(runner->IsValid());
+
+ runner->AllocateTensors();
+
+ runner->SetInput(0, "0.1,0.2,0.3,0.4");
+ runner->SetInput(1, "0.001,0.002,0.003,0.004");
+ runner->SetInput(2, "0.001,0.002,0.003,0.004");
+ runner->SetInput(3, "0.01,0.02,0.03,0.04");
+
+ runner->ResetTensor(2);
+
+ runner->SetExpectation(5, "0.101,0.202,0.303,0.404");
+ runner->SetExpectation(6, "0.011,0.022,0.033,0.044");
+
+ runner->Invoke();
+ ASSERT_TRUE(runner->IsValid());
+
+ ASSERT_TRUE(runner->CheckResults());
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc
new file mode 100644
index 0000000000..2e84ea475c
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+#include <istream>
+#include <string>
+#include "tensorflow/contrib/lite/string.h"
+
+namespace tflite {
+namespace testing {
+
+void Tokenize(std::istream* input, TokenProcessor* processor) {
+ enum State { kBuildQuotedToken, kBuildToken, kIdle };
+
+ std::string current_token;
+ State state = kIdle;
+ auto start_token = [&](char c) {
+ state = kBuildToken;
+ current_token.clear();
+ current_token = c;
+ };
+ auto issue_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto start_quoted_token = [&]() {
+ state = kBuildQuotedToken;
+ current_token.clear();
+ };
+ auto issue_quoted_token = [&]() {
+ state = kIdle;
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto issue_delim = [&](char d) {
+ current_token = string(1, d);
+ processor->ConsumeToken(&current_token);
+ current_token.clear();
+ };
+ auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; };
+ auto is_quote = [](char c) { return c == '"'; };
+
+ for (auto it = std::istreambuf_iterator<char>(*input);
+ it != std::istreambuf_iterator<char>(); ++it) {
+ switch (state) {
+ case kIdle:
+ if (is_delim(*it)) {
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ start_quoted_token();
+ } else if (!isspace(*it)) {
+ start_token(*it);
+ }
+ break;
+ case kBuildToken:
+ if (is_delim(*it)) {
+ issue_token();
+ issue_delim(*it);
+ } else if (is_quote(*it)) {
+ issue_token();
+ start_quoted_token();
+ } else if (isspace(*it)) {
+ issue_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ case kBuildQuotedToken:
+ if (is_quote(*it)) {
+ issue_quoted_token();
+ } else {
+ current_token += *it;
+ }
+ break;
+ }
+ }
+ if (state != kIdle) {
+ issue_token();
+ }
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
new file mode 100644
index 0000000000..daccf0e84a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+
+#include <istream>
+#include <string>
+
+namespace tflite {
+namespace testing {
+
+// Process tokens coming from Tokenize().
+class TokenProcessor {
+ public:
+ virtual ~TokenProcessor() {}
+ // Process a single token. The token won't be reused, so it is OK to call
+ // token.swap().
+ virtual void ConsumeToken(std::string* token) = 0;
+};
+
+// Tokenize a stream on whitespaces, colons and curly braces. Whitespaces are
+// removed from the tokens and double-quotes can be used to avoid that. Note
+// that there is no way to escape double-quotes, so there's no way to have a
+// double-quote inside a token.
+void Tokenize(std::istream* input, TokenProcessor* processor);
+
+} // namespace testing
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc
new file mode 100644
index 0000000000..80f44aacca
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tokenize_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tokenize.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class TokenCollector : public TokenProcessor {
+ public:
+ void ConsumeToken(std::string* token) override { tokens_.push_back(*token); }
+ const std::vector<std::string>& Tokens() { return tokens_; }
+
+ private:
+ std::vector<std::string> tokens_;
+};
+
+std::vector<std::string> TokenizeString(const std::string& s) {
+ std::stringstream ss(s);
+ TokenCollector collector;
+ Tokenize(&ss, &collector);
+ return collector.Tokens();
+}
+
+TEST(TokenizeTest, TokenDetection) {
+ EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1"));
+ EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1"));
+ EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1"));
+ EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1"));
+ EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1"));
+}
+
+TEST(TokenizeTest, QuotedTokenDetection) {
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1"));
+ EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1"));
+}
+
+TEST(TokenizeTest, Delimiters) {
+ EXPECT_THAT(TokenizeString("}"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}"));
+ EXPECT_THAT(TokenizeString("{"), ElementsAre("{"));
+ EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{"));
+ EXPECT_THAT(TokenizeString(":"), ElementsAre(":"));
+ EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":"));
+}
+
+TEST(TokenizeTest, CornerCases) {
+ EXPECT_THAT(TokenizeString(" i { b:a } "),
+ ElementsAre("i", "{", "b", ":", "a", "}"));
+ EXPECT_THAT(TokenizeString(" }"), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" } "), ElementsAre("}"));
+ EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", "}"));
+ EXPECT_THAT(TokenizeString(" x{} y{} "),
+ ElementsAre("x", "{", "}", "y", "{", "}"));
+ EXPECT_THAT(TokenizeString("x:1 y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1\" y:2 "),
+ ElementsAre("x", ":", "1", "y", ":", "2"));
+ EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "),
+ ElementsAre("x", ":", "1, 2", "y", ":", ""));
+}
+
+TEST(TokenizeTest, NewLines) {
+ EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"),
+ ElementsAre("x", ":", "1,", "2", "y", ":", "3"));
+}
+
+TEST(TokenizeTest, LongString) {
+ EXPECT_THAT(
+ TokenizeString(" i { b:a } input {"
+ "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ "
+ "id:1 x{d{a:"
+ "1}}} f:2 "
+ "\n}\n t:1"),
+ ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{",
+ "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{",
+ "id", ":", "1", "x", "{", "d", "{", "a",
+ ":", "1", "}", "}", "}", "f", ":", "2",
+ "}", "t", ":", "1"}));
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite