diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing')
21 files changed, 3585 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD new file mode 100644 index 0000000000..5e40a13d3c --- /dev/null +++ b/tensorflow/contrib/lite/testing/BUILD @@ -0,0 +1,213 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "gen_zipped_test_files", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +gen_zipped_test_files( + name = "optest", + files = [ + "add.zip", + "avg_pool.zip", + "concat.zip", + "constant.zip", + "control_dep.zip", + "conv.zip", + "depthwiseconv.zip", + "fully_connected.zip", + "fused_batch_norm.zip", + "global_batch_norm.zip", + "l2_pool.zip", + "l2norm.zip", + "local_response_norm.zip", + "max_pool.zip", + "mul.zip", + "relu.zip", + "relu1.zip", + "relu6.zip", + "reshape.zip", + "resize_bilinear.zip", + "sigmoid.zip", + "softmax.zip", + "space_to_depth.zip", + ], +) + +py_binary( + name = "generate_examples", + srcs = ["generate_examples.py"], + data = [ + "//tensorflow/contrib/lite/toco", + ], + srcs_version = "PY2AND3", + deps = [ + ":generate_examples_report", + "//tensorflow:tensorflow_py", + "//tensorflow/python:graph_util", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "generate_examples_report", + srcs = ["generate_examples_report.py"], + srcs_version = "PY2AND3", +) + +cc_library( + name = "parse_testdata_lib", + srcs = ["parse_testdata.cc"], + hdrs = ["parse_testdata.h"], + deps = [ + ":message", + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + ], +) + +cc_library( + name = "message", + srcs = ["message.cc"], + hdrs = ["message.h"], + deps = [":tokenize"], +) + +cc_test( + name = "message_test", + srcs = ["message_test.cc"], + deps = [ + ":message", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "split", + srcs = ["split.cc"], + hdrs = ["split.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "split_test", + size = "small", + srcs = ["split_test.cc"], + deps = [ + ":split", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tflite_driver", + srcs = ["tflite_driver.cc"], + hdrs = ["tflite_driver.h"], + deps = [ + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "tflite_driver_test", + size = "small", + srcs = ["tflite_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + deps = [ + ":tflite_driver", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tokenize", + srcs = ["tokenize.cc"], + hdrs = ["tokenize.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "tokenize_test", + srcs = ["tokenize_test.cc"], + deps = [ + ":tokenize", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "test_runner", + hdrs = ["test_runner.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "test_runner_test", + srcs = ["test_runner_test.cc"], + deps = [ + ":test_runner", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "nnapi_example", + srcs = ["nnapi_example.cc"], + deps = [ + ":parse_testdata_lib", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "generated_examples_zip_test", + size = "medium", + srcs = ["generated_examples_zip_test.cc"], + data = [":optest"], + shard_count = 10, + deps = [ + ":parse_testdata_lib", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_googletest//:gtest", + "@com_googlesource_code_re2//:re2", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py new file mode 100644 index 0000000000..86540d58a6 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -0,0 +1,1189 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generate a series of TensorFlow graphs that become tflite test cases. + +Usage: + +generate_examples <output directory> zipped + +bazel run //tensorflow/contrib/lite/testing:generate_examples + third_party/tensorflow/contrib/lite/testing/generated_examples zipped +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import itertools +import os +import re +import sys +import tempfile +import traceback +import zipfile +import numpy as np +from six import StringIO +import tensorflow as tf +from google.protobuf import text_format +# TODO(aselle): switch to TensorFlow's resource_loader +from tensorflow.contrib.lite.testing import generate_examples_report as report_lib +from tensorflow.python.framework import graph_util as tf_graph_util + +parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") +parser.add_argument("output_path", + help="Directory where the outputs will be go.") +# TODO(ahentz): remove this flag +parser.add_argument("type", help="zipped") +parser.add_argument("--zip_to_output", + type=str, + help="Particular zip to output.", + required=False) +parser.add_argument("--toco", + type=str, + help="Path to toco tool.", + required=True) +parser.add_argument( + "--known_bugs_are_errors", + action="store_true", + help=("If a particular model is affected by a known bug," + " count it as a toco error.")) +parser.add_argument( + "--ignore_toco_errors", + action="store_true", + help="Raise an exception if any toco error is encountered.") +parser.add_argument( + "--save_graphdefs", + action="store_true", + help="Include intermediate graphdefs in the output zip files.") + + +RANDOM_SEED = 342 +TEST_INPUT_DEPTH = 3 + + +# A map from regular expression to bug number. Any test failure with label +# matching the expression will be considered due to the corresponding bug. +KNOWN_BUGS = { + # TOCO doesn't support scalars as input. + r"relu.*input_shape=\[\]": "67587484", + r"sigmoid.*input_shape=\[\]": "67645668", + # Concat doesn't work with a single input tensor + r"concat.*num_tensors=1": "67378344", + # Transposition in MatMul is not supported. + r"fully_connected.*transpose_.=True": "67586970", + # Softmax graphs are too complex. + r"softmax.*dim=0": "67749831", + r"softmax.*input_shape=\[1,3,4,3\]": "67749831", + # SpaceToDepth only supports float32. + r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", +} + + +def toco_options(data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency): + """Create TOCO options to process a model. + + Args: + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: name of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + the options in a string. + """ + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + inference_type = "FLOAT" + # TODO(ahentz): if we get multi-input quantization to work we need this + # to change + if data_types[0] == "QUANTIZED_UINT8": + inference_type = "QUANTIZED_UINT8" + s = (" --input_types=%s" % ",".join(data_types) + + " --inference_type=%s" % inference_type + + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + + " --input_arrays=%s" % ",".join(input_arrays) + + " --input_shapes=%s" % shape_str + + " --output_arrays=%s" % ",".join(output_arrays)) + if drop_control_dependency: + s += " --drop_control_dependency" + return s + + +def write_toco_options(filename, + data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency=False): + """Create TOCO options to process a model. + + Args: + filename: Filename to write the options to. + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: names of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + """ + with open(filename, "w") as fp: + fp.write( + toco_options( + data_types=data_types, + input_arrays=input_arrays, + output_arrays=output_arrays, + shapes=shapes, + drop_control_dependency=drop_control_dependency)) + + +def write_examples(fp, examples): + """Given a list `examples`, write a text format representation. + + The file format is csv like with a simple repeated pattern. We would ike + to use proto here, but we can't yet due to interfacing with the Android + team using this format. + + Args: + fp: File-like object to write to. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + def write_tensor(fp, x): + """Write tensor in file format supported by TFLITE example.""" + fp.write("dtype,%s\n" % x.dtype) + fp.write("shape," + ",".join(map(str, x.shape)) + "\n") + # Output 9 digits after the point to ensure the precision is good enough. + values = ["{:.9f}".format(value) for value in list(x.flatten())] + fp.write("values," + ",".join(values) + "\n") + + fp.write("test_cases,%d\n" % len(examples)) + for example in examples: + fp.write("inputs,%d\n" % len(example["inputs"])) + for i in example["inputs"]: + write_tensor(fp, i) + fp.write("outputs,%d\n" % len(example["outputs"])) + for i in example["outputs"]: + write_tensor(fp, i) + + +def write_test_cases(fp, model_name, examples): + """Given a dictionary of `examples`, write a text format representation. + + The file format is protocol-buffer-like, even though we don't use proto due + to the needs of the Android team. + + Args: + fp: File-like object to write to. + model_name: Filename where the model was written to, relative to filename. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + fp.write("load_model: %s\n" % os.path.basename(model_name)) + for example in examples: + fp.write("reshape {\n") + for t in example["inputs"]: + fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n") + fp.write("}\n") + fp.write("invoke {\n") + + for t in example["inputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" input: \"" + ",".join(values) + "\"\n") + for t in example["outputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" output: \"" + ",".join(values) + "\"\n") + fp.write("}\n") + + +_TF_TYPE_INFO = { + tf.float32: (np.float32, "FLOAT"), + tf.float16: (np.float16, "FLOAT"), + tf.int32: (np.int32, "INT32"), + tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int64: (np.int64, "INT64"), +} + + +def create_tensor_data(dtype, shape, min_value=-100, max_value=100): + """Build tensor data spreading the range [min_value, max_value).""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value-min_value)*np.random.random_sample(shape)+min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.random_integers(min_value, max_value, shape) + return value.astype(dtype) + + +def freeze_graph(session, outputs): + """Freeze the current graph. + + Args: + session: Tensorflow sessions containing the graph + outputs: List of output tensors + + Returns: + The frozen graph_def. + """ + return tf_graph_util.convert_variables_to_constants( + session, session.graph.as_graph_def(), [x.op.name for x in outputs]) + + +def make_control_dep_tests(zip_path): + """Make a set of tests that use control dependencies.""" + + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32) + assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1) + with tf.control_dependencies([assert_op]): + out = tf.nn.conv2d(input_tensor, filter_value, + strides=(1, 1, 1, 1), padding="SAME") + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, + drop_control_dependency=True) + + +def toco_convert(graph_def_str, input_tensors, output_tensors, + drop_control_dependency=False): + """Convert a model's graph def into a tflite model. + + NOTE: this currently shells out to the toco binary, but we would like + convert to Python API tooling in the future. + + Args: + graph_def_str: Graph def proto in serialized string format. + input_tensors: List of input tensor tuples `(name, shape, type)` + output_tensors: List of output tensors (names) + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + output tflite model, log_txt from conversion + or None, log_txt if it did not convert properly. + """ + data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors] + opts = toco_options( + data_types=data_types, + input_arrays=[x[0] for x in input_tensors], + shapes=[x[1] for x in input_tensors], + output_arrays=output_tensors, + drop_control_dependency=drop_control_dependency) + + with tempfile.NamedTemporaryFile() as graphdef_file, \ + tempfile.NamedTemporaryFile() as output_file, \ + tempfile.NamedTemporaryFile("w+") as stdout_file: + graphdef_file.write(graph_def_str) + graphdef_file.flush() + + # TODO(aselle): Switch this to subprocess at some point. + cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" % + (bin_path, graphdef_file.name, output_file.name, opts, + stdout_file.name)) + exit_code = os.system(cmd) + log = ( + cmd + "exited with code %d" % exit_code + "\n------------------\n" + + stdout_file.read()) + return (None if exit_code != 0 else output_file.read()), log + + +def make_zip_of_tests(zip_path, + test_parameters, + make_graph, + make_test_inputs, + drop_control_dependency=False): + """Helper to make a zip file of a bunch of TensorFlow models. + + This does a cartestian product of the dictionary of test_parameters and + calls make_graph() for each item in the cartestian product set. + If the graph is built successfully, then make_test_inputs() is called to + build expected input/output value pairs. The model is then converted to tflite + with toco, and the examples are serialized with the tflite model into a zip + file (2 files per item in the cartesian product set). + + Args: + zip_path: Path of zip file to write + test_parameters: Dictionary mapping to lists for each parameter. + e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}` + make_graph: function that takes current parameters and returns tuple + `[input1, input2, ...], [output1, output2, ...]` + make_test_inputs: function taking `curr_params`, `session`, `input_tensors`, + `output_tensors` and returns tuple `(input_values, output_values)`. + drop_control_dependency: whether to ignore control dependency nodes. + Raises: + RuntimeError: if there are toco errors that can't be ignored. + """ + + # TODO(aselle): Make this allow multiple inputs outputs. + archive = zipfile.PyZipFile(zip_path, "w") + zip_manifest = [] + convert_report = [] + toco_errors = 0 + for parameters in test_parameters: + keys = parameters.keys() + for curr in itertools.product(*parameters.values()): + label = zip_path.replace(".zip", "") + (",".join( + "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) + if label[0] == "/": + label = label[1:] + param_dict = dict(zip(keys, curr)) + + def build_example(label, param_dict_real): + """Build the model with parameter values set in param_dict_real. + + Args: + label: Label of the model (i.e. the filename in the zip). + param_dict_real: Parameter dictionary (arguments to the factories + make_graph and make_test_inputs) + Returns: + (tflite_model_binary, report) where tflite_model_binary is the + serialized flatbuffer as a string and report is a dictionary with + keys `toco_log` (log of toco conversion), `tf_log` (log of tf + conversion), `toco` (a string of success status of the conversion), + `tf` (a string success status of the conversion). + """ + + np.random.seed(RANDOM_SEED) + report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED} + + # Build graph + report["tf_log"] = "" + report["toco_log"] = "" + tf.reset_default_graph() + + try: + inputs, outputs = make_graph(param_dict_real) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + + sess = tf.Session() + try: + baseline_inputs, baseline_outputs = (make_test_inputs( + param_dict_real, sess, inputs, outputs)) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + report["toco"] = report_lib.FAILED + report["tf"] = report_lib.SUCCESS + + # Convert graph to toco + tflite_model_binary, toco_log = toco_convert( + sess.graph_def.SerializeToString(), + [(input_tensor.name.split(":")[0], input_tensor.get_shape(), + input_tensor.dtype) for input_tensor in inputs], + [out.name.split(":")[0] + for out in outputs], drop_control_dependency) + report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None + else report_lib.FAILED) + report["toco_log"] = toco_log + + if FLAGS.save_graphdefs: + archive.writestr(label + ".pb", + text_format.MessageToString(sess.graph_def), + zipfile.ZIP_DEFLATED) + + if tflite_model_binary: + archive.writestr(label + ".bin", tflite_model_binary, + zipfile.ZIP_DEFLATED) + example = {"inputs": baseline_inputs, "outputs": baseline_outputs} + + example_fp = StringIO() + write_examples(example_fp, [example]) + archive.writestr(label + ".inputs", + example_fp.getvalue(), zipfile.ZIP_DEFLATED) + + example_fp2 = StringIO() + write_test_cases(example_fp2, label + ".bin", [example]) + archive.writestr(label + "_tests.txt", + example_fp2.getvalue(), zipfile.ZIP_DEFLATED) + + zip_manifest.append(label + "\n") + + return tflite_model_binary, report + + _, report = build_example(label, param_dict) + + if report["toco"] == report_lib.FAILED: + ignore_error = False + if not FLAGS.known_bugs_are_errors: + for pattern, bug_number in KNOWN_BUGS.items(): + if re.search(pattern, label): + print("Ignored TOCO error due to bug %s" % bug_number) + ignore_error = True + if not ignore_error: + toco_errors += 1 + print("-----------------\ntoco error!\n%s\n-----------------\n" % + report["toco_log"]) + + convert_report.append((param_dict, report)) + report_io = StringIO() + report_lib.make_report_table(report_io, zip_path, convert_report) + archive.writestr("report.html", report_io.getvalue()) + + archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED) + + # Log statistics of what succeeded + total_conversions = len(convert_report) + tf_success = sum(1 for x in convert_report + if x[1]["tf"] == report_lib.SUCCESS) + toco_success = sum(1 for x in convert_report + if x[1]["toco"] == report_lib.SUCCESS) + percent = 0 + if tf_success > 0: + percent = float(toco_success) / float(tf_success) * 100. + tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs " + " and %d TOCO converted graphs (%.1f%%"), zip_path, + total_conversions, tf_success, toco_success, percent) + + if not FLAGS.ignore_toco_errors and toco_errors > 0: + raise RuntimeError( + "Found %d errors while generating toco models" % toco_errors) + + +def make_pool_tests(pool_op_in): + """Make a set of tests to do average pooling. + + Args: + pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool`. + + Returns: + A function representing the true generator (after curried pool_op_in). + """ + + pool_op = pool_op_in + + def f(zip_path): + """Actual function that generates examples. + + Args: + zip_path: path to write zip to. + """ + + # Chose a set of parameters + test_parameters = [{ + "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]). + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = pool_op( + input_tensor, + ksize=parameters["ksize"], + strides=parameters["strides"], + data_format=parameters["data_format"], + padding=parameters["padding"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + return f + + +def make_relu_tests(zip_path): + """Make a set of tests to do relu.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], + [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.relu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_relu1_tests(zip_path): + """Make a set of tests to do relu1.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + # Note that the following is not supported: + # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0)) + out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0)) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-3, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_relu6_tests(zip_path): + """Make a set of tests to do relu6.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.relu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-3, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +# This function tests various TensorFLow functions that generates Const op, +# including `tf.ones`, `tf.zeros` and random functions. +def make_constant_tests(zip_path): + """Make a set of tests to do constant ops.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + }] + + def build_graph(parameters): + # Since Toco & Tflite can't have a single constant op in the entire graph, + # this test adds a zero tesnor with a constant op tensor. + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape"]) + out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 + return [input1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = np.zeros(parameters["input_shape"], + dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_add_tests(zip_path): + """Make a set of tests to do add with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.add(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={ + inputs[0]: input1, + inputs[1]: input2 + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_mul_tests(zip_path): + """Make a set of tests to do mul with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.multiply(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={inputs[0]: input1, + inputs[1]: input2}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_global_batch_norm_tests(zip_path): + """Make a set of tests to do batch_norm_with_global_normalization.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]], + "epsilon": [0.1, 0.0001], + "scale_after": [True, False], + }] + + def build_graph(parameters): + """Build the global batch norm testing graph.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + x_norm = tf.nn.batch_norm_with_global_normalization( + x, mean, variance, scale, offset, + parameters["epsilon"], parameters["scale_after"]) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fused_batch_norm_tests(zip_path): + """Make a set of tests to do fused_batch_norm.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2]], + "epsilon": [0.001, 0.1], + }] + + def build_graph(parameters): + """Build the testing graph for fused batch normalization.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + [x_norm, _, _] = tf.nn.fused_batch_norm( + x, scale, offset, mean, variance, + parameters["epsilon"], data_format="NHWC", is_training=False) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_conv_tests(zip_path): + """Make a set of tests to do convolution.""" + + test_parameters = [{ + "input_shape": [[1, 3, 4, 3]], + "filter_shape": [[1, 1, 3, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }, { + "input_shape": [[2, 14, 14, 2]], + "filter_shape": [[6, 6, 2, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_values = create_tensor_data(np.float32, parameters["filter_shape"]) + out = tf.nn.conv2d(input_tensor, filter_values, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_depthwiseconv_tests(zip_path): + """Make a set of tests to do convolution.""" + + # Tensorflow only supports equal strides + test_parameters = [{ + "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "channel_multiplier": [1, 2], + "rate": [[1, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + }, { + "input_shape": [[1, 3, 4, 3]], + "filter_size": [[1, 1]], + "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] + "channel_multiplier": [2], + "rate": [[2, 2]], # Only [1, 1] is supported + "padding": ["SAME"], + "data_format": ["NHWC"], + }] + + def build_graph(parameters): + """Build a depthwise conv graph given `parameters`.""" + input_shape = parameters["input_shape"] + filter_size = parameters["filter_size"] + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"]] + filter_values = create_tensor_data(np.float32, filter_shape) + out = tf.nn.depthwise_conv2d( + input_tensor, filter_values, + strides=parameters["strides"], + rate=parameters["rate"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_concatenation_tests(zip_path): + """Make a set of tests to do concatenatinon.""" + + test_parameters = [{ + "base_shape": [[1, 3, 4, 3], [3, 4]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + }] + + def get_shape(parameters, delta): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += delta + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + shape=get_shape(parameters, n)) + all_tensors.append(input_tensor) + out = tf.concat(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for n in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, + get_shape(parameters, n)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fully_connected_tests(zip_path): + """Make a set of tests to do fully_connected.""" + + test_parameters = [{ + "shape1": [[3, 3]], + "shape2": [[3, 3]], + "transpose_a": [True, False], + "transpose_b": [True, False], + }, { + "shape1": [[4, 4], [1, 4], [4]], + "shape2": [[4, 4], [4, 1], [4]], + "transpose_a": [False], + "transpose_b": [False], + }, { + "shape1": [[40, 37]], + "shape2": [[37, 40]], + "transpose_a": [False], + "transpose_b": [False], + + }] + + def build_graph(parameters): + input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1", + shape=parameters["shape1"]) + input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + out = tf.matmul(input_tensor1, input_tensor2, + transpose_a=parameters["transpose_a"], + transpose_b=parameters["transpose_b"]) + return [input_tensor1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"]) + return [input_values1], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values1]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2norm_tests(zip_path): + """Make a set of tests to do l2norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + "dim": [0, 1, 2, 3, [2, 3], -2], + "epsilon": [None, 1e-12, 1e-3], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + if parameters["epsilon"]: + out = tf.nn.l2_normalize( + input_tensor, parameters["dim"], epsilon=parameters["epsilon"]) + else: + out = tf.nn.l2_normalize(input_tensor, parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_local_response_norm_tests(zip_path): + """Make a set of tests to do local_response_norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], + "depth_radius": [None, 0, 1, 3, 4, 5], + "bias": [None, 0.1, 0.3, -0.1], + "alpha": [None, 1, 2, -3], + "beta": [None, 0.5, 0.25, 2], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.local_response_normalization( + input_tensor, depth_radius=parameters["depth_radius"], + bias=parameters["bias"], alpha=parameters["alpha"], + beta=parameters["beta"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_reshape_tests(zip_path): + """Make a set of tests to do reshape.""" + + # Alll shapes below are suitable for tensors with 420 elements. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], + "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.reshape(input_tensor, shape=parameters["output_shape"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_resize_bilinear_tests(zip_path): + """Make a set of tests to do resize_bilinear.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]], + "size": [[1, 1], [4, 3], [2, 2], [5, 6]], + "align_corners": [None, True, False], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.image.resize_bilinear(input_tensor, size=parameters["size"], + align_corners=parameters["align_corners"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sigmoid_tests(zip_path): + """Make a set of tests to do sigmoid.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.sigmoid(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_softmax_tests(zip_path): + """Make a set of tests to do softmax.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [2, 3]], + "dim": [-1, 0], + }, { + "dtype": [tf.float32], + "input_shape": [[4, 7]], + "dim": [-1, 1], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.nn.softmax(input_tensor, dim=parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_space_to_depth_tests(zip_path): + """Make a set of tests to do space_to_depth.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "input_shape": [[2, 12, 24, 1]], + "block_size": [2, 3, 4], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2_pool(input_tensor, ksize, strides, padding, data_format): + """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" + return tf.sqrt(tf.nn.avg_pool( + tf.square(input_tensor), ksize=ksize, strides=strides, + padding=padding, data_format=data_format)) + + +# Toco binary path provided by the generate rule. +bin_path = None + + +def main(unused_args): + global bin_path + def mkdir_if_not_exist(x): + if not os.path.isdir(x): + os.mkdir(x) + if not os.path.isdir(x): + raise RuntimeError("Failed to create dir %r" % x) + + if FLAGS.type == "zipped": + opstest_path = os.path.join(FLAGS.output_path) + mkdir_if_not_exist(opstest_path) + def _path(filename): + return os.path.join(opstest_path, filename) + + dispatch = { + "control_dep.zip": make_control_dep_tests, + "add.zip": make_add_tests, + "conv.zip": make_conv_tests, + "constant.zip": make_constant_tests, + "depthwiseconv.zip": make_depthwiseconv_tests, + "concat.zip": make_concatenation_tests, + "fully_connected.zip": make_fully_connected_tests, + "global_batch_norm.zip": make_global_batch_norm_tests, + "fused_batch_norm.zip": make_fused_batch_norm_tests, + "l2norm.zip": make_l2norm_tests, + "local_response_norm.zip": make_local_response_norm_tests, + "mul.zip": make_mul_tests, + "relu.zip": make_relu_tests, + "relu1.zip": make_relu1_tests, + "relu6.zip": make_relu6_tests, + "l2_pool.zip": make_pool_tests(make_l2_pool), + "avg_pool.zip": make_pool_tests(tf.nn.avg_pool), + "max_pool.zip": make_pool_tests(tf.nn.max_pool), + "reshape.zip": make_reshape_tests, + "resize_bilinear.zip": make_resize_bilinear_tests, + "sigmoid.zip": make_sigmoid_tests, + "softmax.zip": make_softmax_tests, + "space_to_depth.zip": make_space_to_depth_tests, + } + out = FLAGS.zip_to_output + bin_path = FLAGS.toco + if out in dispatch: + dispatch[out](_path(out)) + else: + raise RuntimeError("Invalid zip to output %r" % out) + + else: + raise RuntimeError("Invalid argument for type of generation.") + + +if __name__ == "__main__": + FLAGS, unparsed = parser.parse_known_args() + + if unparsed: + print("Usage: %s <path out> zipped <zip file to generate>") + else: + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py new file mode 100644 index 0000000000..7bcf8cd86a --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples_report.py @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Make HTML tables that report where TF and TOCO failed to convert models. + +This is primarily used by generate_examples.py. See it or +`make_report_table` for more details on usage. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cgi +import json + +FAILED = "FAILED" +SUCCESS = "SUCCESS" +NOTRUN = "NOTRUN" + + +def make_report_table(fp, title, reports): + """Make an HTML report of the success/failure reports. + + Args: + fp: File-like object in which to put the html. + title: "Title of the zip file this pertains to." + reports: a list of conversion attempts. (report_args, report_vals) i.e. + ({"shape": [1,2,3], "type": "tf.float32"}, + {"tf": "SUCCESS", "toco": "FAILURE", "toco_log": "Unsupported type.", + "tf_log": ""}) + """ + # sort reports by if TOCO failure and then TF failure (reversed) + reports.sort(key=lambda x: x[1]["toco"], reverse=False) + reports.sort(key=lambda x: x[1]["tf"], reverse=True) + def result_cell(x, row, col): + """Produce a cell with the condition string `x`.""" + s = cgi.escape(repr(x), quote=True) + color = "#44ff44" if x == SUCCESS else ( + "#ff4444" if x == FAILED else "#eeeeee") + handler = "ShowLog(%d, %d)" % (row, col) + fp.write("<td style='background-color: %s' onclick='%s'>%s</td>\n" % ( + color, handler, s)) + + fp.write("""<html> +<head> +<title>tflite report</title> +<style> +body { font-family: Arial; } +th { background-color: #555555; color: #eeeeee; } +td { vertical-align: top; } +td.horiz {width: 50%;} +pre { white-space: pre-wrap; word-break: keep-all; } +table {width: 100%;} +</style> +</head> +""") + # Write the log data to a javascript variable and also make a function + # in javascript to show the log when an item is clicked. + fp.write("<script> \n") + fp.write(""" +function ShowLog(row, col) { + +var log = document.getElementById("log"); +log.innerHTML = "<pre>" + data[row][col] + "</pre>"; +} +""") + fp.write("var data = \n") + fp.write(json.dumps([[cgi.escape(x[1]["tf_log"], quote=True), + cgi.escape(x[1]["toco_log"], quote=True)] + for x in reports])) + fp.write(";</script>\n") + + # Write the main table and use onclick on the items that have log items. + fp.write(""" +<body> +<h1>TOCO Conversion</h1> +<h2>%s</h2> +""" % title) + + # Get a list of keys that are in any of the records. + param_keys = {} + for params, _ in reports: + for k in params.keys(): + param_keys[k] = True + + fp.write("<table>\n") + fp.write("<tr><td class='horiz'>\n") + fp.write("<div style='height:1000px; overflow:auto'>\n") + fp.write("<table>\n") + fp.write("<tr>\n") + for p in param_keys: + fp.write("<th>%s</th>\n" % cgi.escape(p, quote=True)) + fp.write("<th>TensorFlow</th>\n") + fp.write("<th>TOCO</th>\n") + fp.write("</tr>\n") + for idx, (params, vals) in enumerate(reports): + fp.write("<tr>\n") + for p in param_keys: + fp.write(" <td>%s</td>\n" % cgi.escape(repr(params[p]), quote=True)) + + result_cell(vals["tf"], idx, 0) + result_cell(vals["toco"], idx, 1) + fp.write("</tr>\n") + fp.write("</table>\n") + fp.write("</div>\n") + fp.write("</td>\n") + fp.write("<td class='horiz' id='log'></td></tr>\n") + fp.write("</table>\n") + fp.write("<script>\n") + fp.write("</script>\n") + fp.write(""" + </body> + </html> + """) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc new file mode 100644 index 0000000000..e7df97ee54 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <cstdarg> +#include <cstdio> +#include <cstdlib> +#include <fstream> +#include <map> +#include <sstream> +#include <gtest/gtest.h> +#include "re2/re2.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool FLAGS_ignore_known_bugs = true; +} // namespace + +namespace tflite { +namespace testing { + +// TensorFlow system environment for file system called. +tensorflow::Env* env = tensorflow::Env::Default(); + +// List of tests that are expected to fail when +// --test_arg=--ignore_known_bugs=false +// Key is a substring of the test name and value is a bug number. +// TODO(ahentz): make sure we clean this list up frequently. +std::map<string, string> kBrokenTests = { + // Add doesn't support broadcasting. + {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + + // Add only supports float32. (and "constant" tests use Add) + {R"(addd.*int32)", "68808744"}, + {R"(constant.*int32)", "68808744"}, + {R"(mul.*int32)", "68808744"}, + + // Toco or TFLite has a bug to deal with some constant functions with + // more than 1 element. + {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"}, + + // L2Norm only supports 4D tensors. + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"}, + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + + // L2Norm only works for dim=-1. + {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + + // ResizeBilinear looks completely incompatible with Tensorflow + {R"(resize_bilinear)", "67964336"}, +}; + +// Allows test data to be unzipped into a temporary directory and makes +// sure those temporary directories are removed later. +class ZipEnvironment : public ::testing::Environment { + public: + ~ZipEnvironment() override {} + + // Delete all temporary directories on teardown. + void TearDown() override { + for (const auto& dir : temporary_directories_) { + tensorflow::int64 undeleted_dirs, undeleted_files; + TF_CHECK_OK( + env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files)); + } + temporary_directories_.clear(); + } + + // Unzip `zip` file into a new temporary directory `out_dir`. + tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) { + string dir; + TF_CHECK_OK(MakeTemporaryDirectory(&dir)); + tensorflow::SubProcess proc; + std::string unzip_binary = + "/usr/bin/unzip"; + proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()}); + proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); + if (!proc.Start()) + return tensorflow::Status(tensorflow::error::UNKNOWN, + "unzip couldn't start"); + string out, err; + int status = proc.Communicate(nullptr, &out, &err); + if (WEXITSTATUS(status) == 0) { + *out_dir = dir; + return tensorflow::Status::OK(); + } else { + return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed"); + } + } + + private: + // Make a temporary directory and return its name in `temporary`. + tensorflow::Status MakeTemporaryDirectory(string* temporary) { + if (env->LocalTempFilename(temporary)) { + TF_CHECK_OK(env->CreateDir(*temporary)); + temporary_directories_.push_back(*temporary); + return tensorflow::Status::OK(); + } + return tensorflow::Status(tensorflow::error::UNKNOWN, + "make temporary directory failed"); + } + + std::vector<string> temporary_directories_; +}; + +// Return the singleton zip_environment. +ZipEnvironment* zip_environment() { + static ZipEnvironment* env = new ZipEnvironment; + return env; +} + +// Read the manifest.txt out of the unarchived zip file. Specifically +// `original_file` is the original zip file for error messages. `dir` is +// the temporary directory where the zip file has been unarchived and +// `test_paths` is the list of test prefixes that were in the manifest. +// Note, it is an error for a manifest to contain no tests. +tensorflow::Status ReadManifest(const std::string& original_file, + const std::string& dir, + std::vector<std::string>* test_paths) { + // Read the newline delimited list of entries in the manifest. + std::ifstream manifest_fp(dir + "/manifest.txt"); + std::string manifest((std::istreambuf_iterator<char>(manifest_fp)), + std::istreambuf_iterator<char>()); + size_t pos = 0; + int added = 0; + while (true) { + size_t end_pos = manifest.find("\n", pos); + if (end_pos == std::string::npos) break; + std::string filename = manifest.substr(pos, end_pos - pos); + test_paths->push_back(dir + "/" + filename); + pos = end_pos + 1; + added += 1; + } + if (!added) { + std::string message = "Test had no examples: " + original_file; + return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str()); + } + return tensorflow::Status::OK(); +} + +// Get a list of tests from a zip file `zip_file_name`. +std::vector<std::string> UnarchiveZipAndFindTestNames( + const std::string& zip_file_name) { + std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() + + "/contrib/lite/testing/optest/" + zip_file_name; + std::string decompress_tmp_dir; + TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); + std::vector<std::string> stuff; + TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + return stuff; +} + +class OpsTest : public ::testing::TestWithParam<std::string> {}; + +TEST_P(OpsTest, RunStuff) { + std::string test_path = GetParam(); + std::string tflite_file = test_path + ".bin"; + std::string tflite_examples = test_path + ".inputs"; + auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str()); + std::unique_ptr<tflite::Interpreter> interpreter; + + tflite::ops::builtin::BuiltinOpResolver builtins; + ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter), + kTfLiteOk); + + std::vector<tflite::testing::Example> examples; + ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples), + kTfLiteOk); + + string bug_number; + for (const auto& p : kBrokenTests) { + if (RE2::PartialMatch(test_path, p.first)) { + bug_number = p.second; + } + } + + for (const auto& example : examples) { + ASSERT_EQ(interpreter->inputs().size(), example.inputs.size()); + auto result = [&]() { + TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example)); + TF_LITE_ENSURE_STATUS(interpreter->Invoke()); + TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example)); + return kTfLiteOk; + }(); + + if (bug_number.empty()) { + ASSERT_EQ(result, kTfLiteOk); + } else { + if (FLAGS_ignore_known_bugs) { + ASSERT_EQ(result, kTfLiteError) + << "Not failing as expected dut to http://b/" << bug_number; + } else { + ASSERT_EQ(result, kTfLiteOk) + << "Possibly due to http://b/" << bug_number; + } + } + } +} + +// Instantiate a test. This assumes `zip_base`.zip is a declared data file +// of this test. +#define INSTANTIATE_TESTS(zip_base) \ + INSTANTIATE_TEST_CASE_P( \ + zip_base, OpsTest, \ + ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip"))); + +INSTANTIATE_TESTS(add) +INSTANTIATE_TESTS(avg_pool) +INSTANTIATE_TESTS(concat) +INSTANTIATE_TESTS(constant) +INSTANTIATE_TESTS(control_dep) +INSTANTIATE_TESTS(conv) +INSTANTIATE_TESTS(depthwiseconv) +INSTANTIATE_TESTS(fully_connected) +INSTANTIATE_TESTS(fused_batch_norm) +INSTANTIATE_TESTS(global_batch_norm) +INSTANTIATE_TESTS(l2norm) +INSTANTIATE_TESTS(l2_pool) +INSTANTIATE_TESTS(local_response_norm) +INSTANTIATE_TESTS(max_pool) +INSTANTIATE_TESTS(mul) +INSTANTIATE_TESTS(relu) +INSTANTIATE_TESTS(relu1) +INSTANTIATE_TESTS(relu6) +INSTANTIATE_TESTS(reshape) +INSTANTIATE_TESTS(resize_bilinear) +INSTANTIATE_TESTS(sigmoid) +INSTANTIATE_TESTS(softmax) +INSTANTIATE_TESTS(space_to_depth) + +} // namespace testing +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment()); + + std::vector<tensorflow::Flag> flags = {tensorflow::Flag( + "ignore_known_bugs", &FLAGS_ignore_known_bugs, + "If a particular model is affected by a known bug, the " + "corresponding test should expect the outputs to not match.")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); + if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return 1; + } + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc new file mode 100644 index 0000000000..03fae4bb86 --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include <stack> + +#include "tensorflow/contrib/lite/testing/tokenize.h" + +namespace tflite { +namespace testing { + +// A token processor that builds messages and forward calls to the current +// message object. Place a new message at the top of the stack when it start +// and remove it when it is finished. +class MessageStack : public TokenProcessor { + public: + // Start a new MessageStack with the given first_node, which will be used to + // process freestanding fields and submessages. + explicit MessageStack(Message* first_node) { + nodes_.push(first_node); + valid_ = true; + } + + void ConsumeToken(std::string* token) override { + if (!valid_) return; + Message* current_node = nodes_.top(); + if (*token == "{") { + // This is the beginning of a new message, names after the previous token. + if (previous_token_.empty()) { + valid_ = false; + return; + } + nodes_.push(current_node ? current_node->AddChild(previous_token_) + : nullptr); + previous_token_.clear(); + } else if (*token == "}") { + // A message is being completed. There should be no previous token. Note + // that the top-level message never closes, so we should always have at + // least one entry in the stack. + if (nodes_.size() == 1 || !previous_token_.empty()) { + valid_ = false; + return; + } + if (current_node) { + current_node->Finish(); + } + nodes_.pop(); + } else if (*token == ":") { + // We reached the end of the 'key' portion of a field. Store the token + // until we have the 'value' portion. + if (previous_token_.empty()) { + valid_ = false; + return; + } + } else { + if (previous_token_.empty()) { + previous_token_.swap(*token); + } else { + // This is the 'value' portion of a field. The previous token is the + // 'key'. + if (current_node) { + current_node->SetField(previous_token_, *token); + } + previous_token_.clear(); + } + } + } + + bool valid() const { return valid_; } + + private: + std::stack<Message*> nodes_; + std::string previous_token_; + bool valid_; +}; + +bool Message::Read(std::istream* input, Message* message) { + MessageStack stack(message); + Tokenize(input, &stack); + return stack.valid(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h new file mode 100644 index 0000000000..78ef7e2cbe --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ + +#include <memory> +#include <string> +#include <vector> + +namespace tflite { +namespace testing { + +// A Message is a textual protobuf-like structure that looks like: +// tag { +// f : "values" +// child { +// a : 1 +// } +// } +// This class provides the framework for processing message but does not +// associate any particular behavior to fields and submessage. In order +// to properly parse a stream this class must be derived. +class Message { + public: + // Reads a stream, tokenizes it and create a new message under the given + // top-level message. Returns true if the parsing succeeded. + static bool Read(std::istream* input, Message* message); + + Message() {} + virtual ~Message() {} + + // Called when a new field is found. For example, when: + // f : "values" + // is found, it triggers: + // SetField("f", "values"); + virtual void SetField(const std::string& name, const std::string& value) {} + + // Called when a submessage is started. For example, when: + // child { + // is found, it triggers + // AddChild("child"); + // If nullptr is returned, the contents of the submessage will be ignored. + // Otherwise, the returned Message will be used to handle new fields and new + // submessages. The caller should not take ownership of the returned pointer. + virtual Message* AddChild(const std::string& name) { return nullptr; } + + // Called when a submessage is completed, that is, whenever a '}' is found. + virtual void Finish() {} + + protected: + // Takes ownership of the given pointer. Subclasses can use this method if + // they don't want to implement their own ownership semantics. + Message* Store(Message* n) { + children_.emplace_back(n); + return n; + } + + // Returns a list of all owned submessages. + const std::vector<std::unique_ptr<Message>>& Children() const { + return children_; + } + + private: + std::vector<std::unique_ptr<Message>> children_; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc new file mode 100644 index 0000000000..fb6a49bd6f --- /dev/null +++ b/tensorflow/contrib/lite/testing/message_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include <map> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +// A hierarchical, key-value store. +class TestMessage : public Message { + public: + TestMessage() {} + explicit TestMessage(const std::string& text_to_parse) { + std::stringstream ss(text_to_parse); + finished_ = Message::Read(&ss, this); + } + void SetField(const std::string& name, const std::string& value) override { + fields_[name] = value; + } + Message* AddChild(const std::string& name) override { + TestMessage* m = new TestMessage; + m->name_ = name; + return Store(m); + } + void Finish() override { finished_ = true; } + + int NumChildren() const { return Children().size(); } + + const TestMessage* GetChild(int i) const { + return dynamic_cast<TestMessage*>(Children()[i].get()); + } + + int NumFields() const { return fields_.size(); } + const std::string& GetField(const std::string& key) const { + return fields_.at(key); + } + + const std::string& name() const { return name_; } + bool finished() const { return finished_; } + + protected: + std::string name_; + std::map<std::string, std::string> fields_; + bool finished_ = false; +}; + +TEST(MessageTest, Simple) { + TestMessage message("x{a:1 b:2} y{} z{c:3} d:4"); + ASSERT_TRUE(message.finished()); + + ASSERT_EQ(message.NumFields(), 1); + EXPECT_EQ(message.GetField("d"), "4"); + + ASSERT_EQ(message.NumChildren(), 3); + + auto* x = message.GetChild(0); + EXPECT_EQ(x->name(), "x"); + ASSERT_EQ(x->NumFields(), 2); + EXPECT_EQ(x->GetField("a"), "1"); + EXPECT_EQ(x->GetField("b"), "2"); + + auto* y = message.GetChild(1); + EXPECT_EQ(y->name(), "y"); + ASSERT_EQ(y->NumFields(), 0); + + auto* z = message.GetChild(2); + EXPECT_EQ(z->name(), "z"); + ASSERT_EQ(z->NumFields(), 1); + EXPECT_EQ(z->GetField("c"), "3"); +} + +TEST(MessageTest, Unnamed) { + TestMessage message("x{c:3} {} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, TooManyBraces) { + TestMessage message("x{c:3} } y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, LeftoverToken) { + TestMessage message("x{c:3} z{test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingKey) { + TestMessage message("x{c:3} z{:test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingValue) { + TestMessage message("x{c:3} z{test:} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc new file mode 100644 index 0000000000..74f6cfc3de --- /dev/null +++ b/tensorflow/contrib/lite/testing/nnapi_example.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// NOTE: this is an example driver that converts a tflite model to TensorFlow. +// This is an example that will be integrated more tightly into tflite in +// the future. +// +// Usage: bazel run -c opt \ +// tensorflow/contrib/lite/nnapi:nnapi_example -- <filename> +// +#include <cstdarg> +#include <cstdio> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +// TODO(aselle): FATAL leaves resources hanging. +void FATAL(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fflush(stderr); + exit(1); +} + +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +void Interpret(const char* filename, const char* examples_filename, + bool use_nnapi) { + // TODO(aselle): Resize of input image should go here + // ... + // For now I am allocating all tensors. This means I am fixed size. + // So I am not using the variable size ability yet. + fprintf(stderr, "example file %s\n", examples_filename); + std::vector<tflite::testing::Example> examples; + CHECK_TFLITE_SUCCESS( + tflite::testing::ParseExamples(examples_filename, &examples)); + + for (const tflite::testing::Example& example : examples) { + auto model = tflite::FlatBufferModel::BuildFromFile(filename); + if (!model) FATAL("Cannot read file %s\n", filename); + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::ops::builtin::BuiltinOpResolver builtins; + + CHECK_TFLITE_SUCCESS( + tflite::InterpreterBuilder(*model, builtins)(&interpreter)); + + printf("Use nnapi is set to: %d\n", use_nnapi); + interpreter->UseNNAPI(use_nnapi); + CHECK_TFLITE_SUCCESS( + tflite::testing::FeedExample(interpreter.get(), example)); + + { + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor<float>(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + *p = 0; + } + } + } + interpreter->Invoke(); + + CHECK_TFLITE_SUCCESS( + tflite::testing::CheckOutputs(interpreter.get(), example)); + + printf("Result:\n"); + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor<float>(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + printf(" %f", *p); + } + } + } +} + +int main(int argc, char* argv[]) { + bool use_nnapi = true; + if (argc == 4) { + use_nnapi = strcmp(argv[3], "1") == 0 ? true : false; + } + if (argc < 3) { + fprintf(stderr, + "Compiled " __DATE__ __TIME__ + "\n" + "Usage!!!: %s <tflite model> <examples to test> " + "{ use nn api i.e. 0,1}\n", + argv[0]); + return 1; + } + Interpret(argv[1], argv[2], use_nnapi); + return 0; +} diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc new file mode 100644 index 0000000000..2b67052cad --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -0,0 +1,335 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Parses tflite example input data. +// Format is ASCII +// TODO(aselle): Switch to protobuf, but the android team requested a simple +// ASCII file. +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +#include <cmath> +#include <cstdint> +#include <cstdio> +#include <fstream> +#include <iostream> +#include <streambuf> + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/message.h" +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { +namespace { + +// Fatal error if parse error occurs +#define PARSE_CHECK_EQ(filename, current_line, x, y) \ + if ((x) != (y)) { \ + fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \ + __FILE__, __LINE__, filename, current_line + 1, #x, #y); \ + return kTfLiteError; \ + } + +// Breakup a "," delimited line into a std::vector<std::string>. +// This is extremely inefficient, and just used for testing code. +// TODO(aselle): replace with absl when we use it. +std::vector<std::string> ParseLine(const std::string& line) { + size_t pos = 0; + std::vector<std::string> elements; + while (true) { + size_t end = line.find(',', pos); + if (end == std::string::npos) { + elements.push_back(line.substr(pos)); + break; + } else { + elements.push_back(line.substr(pos, end - pos)); + } + pos = end + 1; + } + return elements; +} + +} // namespace + +// Given a `filename`, produce a vector of Examples corresopnding +// to test cases that can be applied to a tflite model. +TfLiteStatus ParseExamples(const char* filename, + std::vector<Example>* examples) { + std::ifstream fp(filename); + if (!fp.good()) { + fprintf(stderr, "Could not read '%s'\n", filename); + return kTfLiteError; + } + std::string str((std::istreambuf_iterator<char>(fp)), + std::istreambuf_iterator<char>()); + size_t pos = 0; + + // \n and , delimit parse a file. + std::vector<std::vector<std::string>> csv; + while (true) { + size_t end = str.find('\n', pos); + + if (end == std::string::npos) { + csv.emplace_back(ParseLine(str.substr(pos))); + break; + } + csv.emplace_back(ParseLine(str.substr(pos, end - pos))); + pos = end + 1; + } + + int current_line = 0; + PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases"); + int example_count = std::stoi(csv[0][1]); + current_line++; + + auto parse_tensor = [&filename, ¤t_line, + &csv](FloatTensor* tensor_ptr) { + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype"); + current_line++; + // parse shape + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape"); + size_t elements = 1; + FloatTensor& tensor = *tensor_ptr; + + for (size_t i = 1; i < csv[current_line].size(); i++) { + const auto& shape_part_to_parse = csv[current_line][i]; + if (shape_part_to_parse.empty()) { + // Case of a 0-dimensional shape + break; + } + int shape_part = std::stoi(shape_part_to_parse); + elements *= shape_part; + tensor.shape.push_back(shape_part); + } + current_line++; + // parse data + PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1, + elements); + for (size_t i = 1; i < csv[current_line].size(); i++) { + tensor.flat_data.push_back(std::stof(csv[current_line][i])); + } + current_line++; + + return kTfLiteOk; + }; + + for (int example_idx = 0; example_idx < example_count; example_idx++) { + Example example; + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs"); + int inputs = std::stoi(csv[current_line][1]); + current_line++; + // parse dtype + for (int input_index = 0; input_index < inputs; input_index++) { + example.inputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back())); + } + + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs"); + int outputs = std::stoi(csv[current_line][1]); + current_line++; + for (int input_index = 0; input_index < outputs; input_index++) { + example.outputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back())); + } + examples->emplace_back(example); + } + return kTfLiteOk; +} + +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, + const Example& example) { + // Resize inputs to match example & allocate. + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + + TF_LITE_ENSURE_STATUS( + interpreter->ResizeInputTensor(input_index, example.inputs[i].shape)); + } + TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors()); + // Copy data into tensors. + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + if (float* data = interpreter->typed_tensor<float>(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else if (int32_t* data = + interpreter->typed_tensor<int32_t>(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else { + fprintf(stderr, "input[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, + const Example& example) { + constexpr double kRelativeThreshold = 1e-2f; + constexpr double kAbsoluteThreshold = 1e-4f; + + ErrorReporter* context = DefaultErrorReporter(); + int model_outputs = interpreter->outputs().size(); + TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size()); + for (size_t i = 0; i < interpreter->outputs().size(); i++) { + int output_index = interpreter->outputs()[i]; + if (const float* data = interpreter->typed_tensor<float>(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + float computed = data[idx]; + float reference = example.outputs[0].flat_data[idx]; + float diff = std::abs(computed - reference); + bool error_is_large = false; + // For very small numbers, try absolute error, otherwise go with + // relative. + if (std::abs(reference) < kRelativeThreshold) { + error_is_large = (diff > kAbsoluteThreshold); + } else { + error_is_large = (diff > kRelativeThreshold * std::abs(reference)); + } + if (error_is_large) { + fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n", + i, idx, data[idx], reference); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else if (const int32_t* data = + interpreter->typed_tensor<int32_t>(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + int32_t computed = data[idx]; + int32_t reference = example.outputs[0].flat_data[idx]; + if (std::abs(computed - reference) > 0) { + fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n", + i, idx, data[idx], example.outputs[0].flat_data[idx]); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else { + fprintf(stderr, "output[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +// Process an 'invoke' message, triggering execution of the test runner, as +// well as verification of outputs. An 'invoke' message looks like: +// invoke { +// id: xyz +// input: 1,2,1,1,1,2,3,4 +// ouput: 4,5,6 +// } +class Invoke : public Message { + public: + explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + expected_outputs_ = test_runner->GetOutputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "id") { + test_runner_->SetInvocationId(value); + } else if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs"); + } + test_runner_->SetInput(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } else if (name == "output") { + if (expected_outputs_.empty()) { + return test_runner_->Invalidate("Too many outputs"); + } + test_runner_->SetExpectation(*expected_outputs_.begin(), value); + expected_outputs_.erase(expected_outputs_.begin()); + } + } + void Finish() override { + test_runner_->Invoke(); + test_runner_->CheckResults(); + } + + private: + std::vector<int> expected_inputs_; + std::vector<int> expected_outputs_; + + TestRunner* test_runner_; +}; + +// Process an 'reshape' message, triggering resizing of the input tensors via +// the test runner. A 'reshape' message looks like: +// reshape { +// input: 1,2,1,1,1,2,3,4 +// } +class Reshape : public Message { + public: + explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs to reshape"); + } + test_runner_->ReshapeTensor(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } + } + + private: + std::vector<int> expected_inputs_; + TestRunner* test_runner_; +}; + +// This is the top-level message in a test file. +class TestData : public Message { + public: + explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {} + + void SetField(const std::string& name, const std::string& value) override { + if (name == "load_model") { + test_runner_->LoadModel(value); + } else if (name == "init_state") { + test_runner_->AllocateTensors(); + for (int id : Split<int>(value, ",")) { + test_runner_->ResetTensor(id); + } + } + } + Message* AddChild(const std::string& s) override { + if (s == "invoke") { + test_runner_->AllocateTensors(); + return Store(new Invoke(test_runner_)); + } else if (s == "reshape") { + return Store(new Reshape(test_runner_)); + } + return nullptr; + } + + private: + TestRunner* test_runner_; +}; + +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) { + TestData test_data(test_runner); + Message::Read(input, &test_data); + return test_runner->IsValid() && test_runner->GetOverallSuccess(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h new file mode 100644 index 0000000000..90839fe245 --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ + +#include <vector> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// Shape and data for a float tensor +struct FloatTensor { + std::vector<int> shape; + std::vector<float> flat_data; +}; + +// A prescribed input, output example +struct Example { + std::vector<FloatTensor> inputs; + std::vector<FloatTensor> outputs; +}; + +// Parses an example input and output file (used for unit tests) +TfLiteStatus ParseExamples(const char* filename, + std::vector<Example>* examples); + +// Inputs Tensors into a TensorFlow lite interpreter. Note, this will run +// interpreter.AllocateTensors(); +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&); + +// Check outputs against (already) evaluated result. +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&); + +// Parses a test description and feeds the given test runner with data. +// The input format is similar to an ASCII proto: +// // Loads model 'add.bin' from the TestRunner's model directory. +// load_model: "add.bin" +// // Changes the shape of inputs, provided in the same order they appear +// // in the model. +// reshape { +// input: "1,224,224,3" +// input: "1,3,4,1" +// } +// // Fills the given persistent tensors with zeros. +// init_state: 0,1,2,3 +// // Invokes the interpreter with the given input and checks that it +// // produces the expected output. Inputs and outputs should be specified in +// // the order they appear in the model. +// invoke { +// input: "1,2,3,4,56" +// input: "0.1,0.2,0.3,4.3,56.4" +// output: "12,3,4,545,3" +// output: "0.01,0.02" +// } +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc new file mode 100644 index 0000000000..5836f4ff04 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s, + const string& delimiter) { + std::vector<std::pair<size_t, size_t>> fields; + if (delimiter.length() == 0) { + fields.emplace_back(0, s.length()); + return fields; + } + size_t pos = 0; + size_t start = 0; + while ((pos = s.find(delimiter, start)) != string::npos) { + if (pos != start) { + fields.emplace_back(start, pos); + } + start = pos + delimiter.length(); + } + if (start != s.length()) { + fields.emplace_back(start, s.length()); + } + return fields; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h new file mode 100644 index 0000000000..24071442e8 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ + +#include <cstdlib> +#include <string> +#include <utility> +#include <vector> +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// Splits a string based on the given delimiter string. Each pair in the +// returned vector has the start and past-the-end positions for each of the +// parts of the original string. Empty fields are not represented in the +// output. +std::vector<std::pair<size_t, size_t>> SplitToPos(const string& s, + const string& delimiter); + +// Splits the given string and converts each part to the given T. +template <typename T> +std::vector<T> Split(const string& s, const string& delimiter); + +template <> +inline std::vector<string> Split(const string& s, const string& delimiter) { + std::vector<string> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(s.substr(p.first, p.second - p.first)); + } + return fields; +} + +template <> +inline std::vector<int> Split(const string& s, const string& delimiter) { + std::vector<int> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtol(s.data() + p.first, nullptr, 10)); + } + return fields; +} + +template <> +inline std::vector<float> Split(const string& s, const string& delimiter) { + std::vector<float> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtod(s.data() + p.first, nullptr)); + } + return fields; +} + +template <> +inline std::vector<uint8_t> Split(const string& s, const string& delimiter) { + std::vector<uint8_t> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtol(s.data() + p.first, nullptr, 10)); + } + return fields; +} + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc new file mode 100644 index 0000000000..3d1e25d9c7 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::Pair; + +TEST(SplitTest, SplitToPos) { + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"), + ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"), + ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4))); + EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("", ":"), ElementsAre()); + EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre()); +} + +TEST(SplitTest, SplitString) { + EXPECT_THAT(Split<string>("A;B;C", ";"), ElementsAre("A", "B", "C")); +} + +TEST(SplitTest, SplitFloat) { + EXPECT_THAT(Split<float>("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5)); +} + +TEST(SplitTest, SplitInt) { + EXPECT_THAT(Split<int>("1,-1,258", ","), ElementsAre(1, -1, 258)); +} + +TEST(SplitTest, SplitUint8) { + EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2)); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h new file mode 100644 index 0000000000..04ee4d9f7d --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ + +#include <memory> +#include <string> +#include <vector> +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// This is the base class for processing test data. Each one of the virtual +// methods must be implemented to forward the data to the appropriate executor +// (e.g. TF Lite's interpreter, or the NNAPI). +class TestRunner { + public: + TestRunner() {} + virtual ~TestRunner() {} + + // Load the given model, as a path relative to SetModelBaseDir(). + virtual void LoadModel(const string& bin_file_path) = 0; + + // Return the list of input tensors in the loaded model. + virtual const std::vector<int>& GetInputs() = 0; + + // Return the list of output tensors in the loaded model. + virtual const std::vector<int>& GetOutputs() = 0; + + // Prepare for a run by resize the given tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void ReshapeTensor(int id, const string& csv_values) = 0; + + // Reserve memory for all tensors. + virtual void AllocateTensors() = 0; + + // Set the given tensor to some initial state, usually zero. This is + // used to reset persistent buffers in a model. + virtual void ResetTensor(int id) = 0; + + // Define the contents of the given input tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void SetInput(int id, const string& csv_values) = 0; + + // Define what should be expected for an output tensor after Invoke() runs. + // The given 'id' is guaranteed to be one of the ids returned by + // GetOutputs(). + virtual void SetExpectation(int id, const string& csv_values) = 0; + + // Run the model. + virtual void Invoke() = 0; + + // Verify that the contents of all ouputs conform to the existing + // expectations. Return true if there are no expectations or they are all + // satisfied. + virtual bool CheckResults() = 0; + + // Set the base path for loading models. + void SetModelBaseDir(const string& path) { + model_base_dir_ = path; + if (path[path.length() - 1] != '/') { + model_base_dir_ += "/"; + } + } + + // Return the full path of a model. + string GetFullPath(const string& path) { return model_base_dir_ + path; } + + // Give an id to the next invocation to make error reporting more meaningful. + void SetInvocationId(const string& id) { invocation_id_ = id; } + const string& GetInvocationId() const { return invocation_id_; } + + // Invalidate the test runner, preventing it from executing any further. + void Invalidate(const string& error_message) { + error_message_ = error_message; + } + bool IsValid() const { return error_message_.empty(); } + const string& GetErrorMessage() const { return error_message_; } + + // Handle the overall success of this test runner. This will be true if all + // invocations were successful. + void SetOverallSuccess(bool value) { overall_success_ = value; } + bool GetOverallSuccess() const { return overall_success_; } + + protected: + // A helper to check of the given number of values is consistent with the + // number of bytes in a tensor of type T. When incompatibles sizes are found, + // the test runner is invalidated and false is returned. + template <typename T> + bool CheckSizes(size_t tensor_bytes, size_t num_values) { + size_t num_tensor_elements = tensor_bytes / sizeof(T); + if (num_tensor_elements != num_values) { + Invalidate("Expected '" + std::to_string(num_tensor_elements) + + "' elements for a tensor, but only got '" + + std::to_string(num_values) + "'"); + return false; + } + return true; + } + + private: + string model_base_dir_; + string invocation_id_; + bool overall_success_ = true; + + string error_message_; +}; + +} // namespace testing +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc new file mode 100644 index 0000000000..f712a5347a --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/test_runner.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +class ConcreteTestRunner : public TestRunner { + public: + void LoadModel(const string& bin_file_path) override {} + const std::vector<int>& GetInputs() override { return ids_; } + const std::vector<int>& GetOutputs() override { return ids_; } + void ReshapeTensor(int id, const string& csv_values) override {} + void AllocateTensors() override {} + void ResetTensor(int id) override {} + void SetInput(int id, const string& csv_values) override {} + void SetExpectation(int id, const string& csv_values) override {} + void Invoke() override {} + bool CheckResults() override { return true; } + bool CheckFloatSizes(size_t bytes, size_t values) { + return CheckSizes<float>(bytes, values); + } + + private: + std::vector<int> ids_; +}; + +TEST(TestRunner, ModelPath) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin"); + runner.SetModelBaseDir("/tmp"); + EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin"); +} + +TEST(TestRunner, InvocationId) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetInvocationId(), ""); + runner.SetInvocationId("X"); + EXPECT_EQ(runner.GetInvocationId(), "X"); +} + +TEST(TestRunner, Invalidation) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), ""); + runner.Invalidate("Some Error"); + EXPECT_FALSE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), "Some Error"); +} + +TEST(TestRunner, OverallSuccess) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.GetOverallSuccess()); + runner.SetOverallSuccess(false); + EXPECT_FALSE(runner.GetOverallSuccess()); +} + +TEST(TestRunner, CheckSizes) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.CheckFloatSizes(16, 4)); + EXPECT_FALSE(runner.CheckFloatSizes(16, 2)); + EXPECT_EQ(runner.GetErrorMessage(), + "Expected '4' elements for a tensor, but only got '2'"); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc new file mode 100644 index 0000000000..cf9df2ec26 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include <iostream> + +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +namespace { + +// Returns the value in the given position in a tensor. +template <typename T> +T Value(const TfLitePtrUnion& data, int index); +template <> +float Value(const TfLitePtrUnion& data, int index) { + return data.f[index]; +} +template <> +uint8_t Value(const TfLitePtrUnion& data, int index) { + return data.uint8[index]; +} + +template <typename T> +void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) { + T* input_ptr = reinterpret_cast<T*>(data->raw); + for (const T& v : values) { + *input_ptr = v; + ++input_ptr; + } +} + +} // namespace + +class TfLiteDriver::Expectation { + public: + Expectation() { data_.raw = nullptr; } + ~Expectation() { delete[] data_.raw; } + template <typename T> + void SetData(const string& csv_values) { + const auto& values = testing::Split<T>(csv_values, ","); + data_.raw = new char[values.size() * sizeof(T)]; + SetTensorData(values, &data_); + } + + bool Check(bool verbose, const TfLiteTensor& tensor) { + switch (tensor.type) { + case kTfLiteFloat32: + return TypedCheck<float>(verbose, tensor); + case kTfLiteUInt8: + return TypedCheck<uint8_t>(verbose, tensor); + default: + return false; + } + } + + private: + template <typename T> + bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { + int tensor_size = tensor.bytes / sizeof(T); + + bool good_output = true; + for (int i = 0; i < tensor_size; ++i) { + if (std::abs(Value<T>(data_, i) - Value<T>(tensor.data, i)) > 1e-5) { + good_output = false; + if (verbose) { + std::cerr << " index " << i << ": " << Value<T>(data_, i) + << " != " << Value<T>(tensor.data, i) << std::endl; + } + } + } + return good_output; + } + + TfLitePtrUnion data_; +}; + +TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} +TfLiteDriver::~TfLiteDriver() {} + +void TfLiteDriver::AllocateTensors() { + if (must_allocate_tensors_) { + if (interpreter_->AllocateTensors() != kTfLiteOk) { + std::cerr << "Failed to allocate tensors" << std::endl; + abort(); + } + must_allocate_tensors_ = false; + } +} + +void TfLiteDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + + model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str()); + if (!model_) { + Invalidate("Failed to mmap model " + bin_file_path); + return; + } + ops::builtin::BuiltinOpResolver builtins; + InterpreterBuilder(*model_, builtins)(&interpreter_); + if (!interpreter_) { + Invalidate("Failed build interpreter"); + return; + } + + must_allocate_tensors_ = true; +} + +void TfLiteDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + memset(tensor->data.raw, 0, tensor->bytes); +} + +void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) { + if (!IsValid()) return; + if (interpreter_->ResizeInputTensor( + id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) { + Invalidate("Failed to resize input tensor " + std::to_string(id)); + return; + } + must_allocate_tensors_ = true; +} + +void TfLiteDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + switch (tensor->type) { + case kTfLiteFloat32: { + const auto& values = testing::Split<float>(csv_values, ","); + if (!CheckSizes<float>(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + case kTfLiteUInt8: { + const auto& values = testing::Split<uint8_t>(csv_values, ","); + if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::SetExpectation(int id, const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + expected_output_[id].reset(new Expectation); + switch (tensor->type) { + case kTfLiteFloat32: + expected_output_[id]->SetData<float>(csv_values); + break; + case kTfLiteUInt8: + expected_output_[id]->SetData<uint8_t>(csv_values); + break; + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::Invoke() { + if (!IsValid()) return; + if (interpreter_->Invoke() != kTfLiteOk) { + Invalidate("Failed to invoke interpreter"); + } +} + +bool TfLiteDriver::CheckResults() { + if (!IsValid()) return false; + bool success = true; + for (const auto& p : expected_output_) { + int id = p.first; + auto* tensor = interpreter_->tensor(id); + if (!p.second->Check(/*verbose=*/false, *tensor)) { + // Do not invalidate anything here. Instead, simply output the + // differences and return false. Invalidating would prevent all + // subsequent invocations from running.. + std::cerr << "There were errors in invocation '" << GetInvocationId() + << "', output tensor '" << id << "':" << std::endl; + p.second->Check(/*verbose=*/true, *tensor); + success = false; + SetOverallSuccess(false); + } + } + expected_output_.clear(); + return success; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h new file mode 100644 index 0000000000..4440d4285e --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ + +#include <map> + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into TF Lite and verifies its outputs. +class TfLiteDriver : public TestRunner { + public: + explicit TfLiteDriver(bool use_nnapi); + ~TfLiteDriver() override; + + void LoadModel(const string& bin_file_path) override; + const std::vector<int>& GetInputs() override { + return interpreter_->inputs(); + } + const std::vector<int>& GetOutputs() override { + return interpreter_->outputs(); + } + void ReshapeTensor(int id, const string& csv_values) override; + void AllocateTensors() override; + void ResetTensor(int id) override; + void SetInput(int id, const string& csv_values) override; + void SetExpectation(int id, const string& csv_values) override; + void Invoke() override; + bool CheckResults() override; + + private: + class Expectation; + + bool use_nnapi_ = false; + std::unique_ptr<FlatBufferModel> model_; + std::unique_ptr<Interpreter> interpreter_; + std::map<int, std::unique_ptr<Expectation>> expected_output_; + bool must_allocate_tensors_ = true; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc new file mode 100644 index 0000000000..37010c468f --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfliteDriverTest, SimpleTest) { + std::unique_ptr<TestRunner> runner(new TfLiteDriver(/*use_nnapi=*/false)); + + runner->SetModelBaseDir("tensorflow/contrib/lite"); + runner->LoadModel("testdata/multi_add.bin"); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->AllocateTensors(); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + + runner->ResetTensor(2); + + runner->SetExpectation(5, "0.101,0.202,0.303,0.404"); + runner->SetExpectation(6, "0.011,0.022,0.033,0.044"); + + runner->Invoke(); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_TRUE(runner->CheckResults()); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc new file mode 100644 index 0000000000..2e84ea475c --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tokenize.h" +#include <istream> +#include <string> +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +void Tokenize(std::istream* input, TokenProcessor* processor) { + enum State { kBuildQuotedToken, kBuildToken, kIdle }; + + std::string current_token; + State state = kIdle; + auto start_token = [&](char c) { + state = kBuildToken; + current_token.clear(); + current_token = c; + }; + auto issue_token = [&]() { + state = kIdle; + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto start_quoted_token = [&]() { + state = kBuildQuotedToken; + current_token.clear(); + }; + auto issue_quoted_token = [&]() { + state = kIdle; + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto issue_delim = [&](char d) { + current_token = string(1, d); + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; }; + auto is_quote = [](char c) { return c == '"'; }; + + for (auto it = std::istreambuf_iterator<char>(*input); + it != std::istreambuf_iterator<char>(); ++it) { + switch (state) { + case kIdle: + if (is_delim(*it)) { + issue_delim(*it); + } else if (is_quote(*it)) { + start_quoted_token(); + } else if (!isspace(*it)) { + start_token(*it); + } + break; + case kBuildToken: + if (is_delim(*it)) { + issue_token(); + issue_delim(*it); + } else if (is_quote(*it)) { + issue_token(); + start_quoted_token(); + } else if (isspace(*it)) { + issue_token(); + } else { + current_token += *it; + } + break; + case kBuildQuotedToken: + if (is_quote(*it)) { + issue_quoted_token(); + } else { + current_token += *it; + } + break; + } + } + if (state != kIdle) { + issue_token(); + } +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h new file mode 100644 index 0000000000..daccf0e84a --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ + +#include <istream> +#include <string> + +namespace tflite { +namespace testing { + +// Process tokens coming from Tokenize(). +class TokenProcessor { + public: + virtual ~TokenProcessor() {} + // Process a single token. The token won't be reused, so it is OK to call + // token.swap(). + virtual void ConsumeToken(std::string* token) = 0; +}; + +// Tokenize a stream on whitespaces, colons and curly braces. Whitespaces are +// removed from the tokens and double-quotes can be used to avoid that. Note +// that there is no way to escape double-quotes, so there's no way to have a +// double-quote inside a token. +void Tokenize(std::istream* input, TokenProcessor* processor); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc new file mode 100644 index 0000000000..80f44aacca --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tokenize.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class TokenCollector : public TokenProcessor { + public: + void ConsumeToken(std::string* token) override { tokens_.push_back(*token); } + const std::vector<std::string>& Tokens() { return tokens_; } + + private: + std::vector<std::string> tokens_; +}; + +std::vector<std::string> TokenizeString(const std::string& s) { + std::stringstream ss(s); + TokenCollector collector; + Tokenize(&ss, &collector); + return collector.Tokens(); +} + +TEST(TokenizeTest, TokenDetection) { + EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1")); + EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1")); +} + +TEST(TokenizeTest, QuotedTokenDetection) { + EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1")); + EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1")); +} + +TEST(TokenizeTest, Delimiters) { + EXPECT_THAT(TokenizeString("}"), ElementsAre("}")); + EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}")); + EXPECT_THAT(TokenizeString("{"), ElementsAre("{")); + EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{")); + EXPECT_THAT(TokenizeString(":"), ElementsAre(":")); + EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":")); +} + +TEST(TokenizeTest, CornerCases) { + EXPECT_THAT(TokenizeString(" i { b:a } "), + ElementsAre("i", "{", "b", ":", "a", "}")); + EXPECT_THAT(TokenizeString(" }"), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" } "), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", "}")); + EXPECT_THAT(TokenizeString(" x{} y{} "), + ElementsAre("x", "{", "}", "y", "{", "}")); + EXPECT_THAT(TokenizeString("x:1 y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1\" y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "), + ElementsAre("x", ":", "1, 2", "y", ":", "")); +} + +TEST(TokenizeTest, NewLines) { + EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"), + ElementsAre("x", ":", "1,", "2", "y", ":", "3")); +} + +TEST(TokenizeTest, LongString) { + EXPECT_THAT( + TokenizeString(" i { b:a } input {" + "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ " + "id:1 x{d{a:" + "1}}} f:2 " + "\n}\n t:1"), + ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{", + "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{", + "id", ":", "1", "x", "{", "d", "{", "a", + ":", "1", "}", "}", "}", "f", ":", "2", + "}", "t", ":", "1"})); +} + +} // namespace +} // namespace testing +} // namespace tflite |