diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 185 |
1 files changed, 159 insertions, 26 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 1360f1a273..41ece94237 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -94,8 +94,8 @@ KNOWN_BUGS = { r"sigmoid.*input_shape=\[\]": "67645668", # Concat doesn't work with a single input tensor r"concat.*num_tensors=1": "67378344", - # Transposition in MatMul is not supported. - r"fully_connected.*transpose_.=True": "67586970", + # Transposition in MatMul is not fully supported. + "fully_connected.*transpose_a=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", # BatchToSpaceND only supports 4D tensors. @@ -678,6 +678,55 @@ def make_relu6_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_prelu_tests(zip_path): + """Make a set of tests to do PReLU.""" + + test_parameters = [{ + # The canonical case for image processing is having a 4D `input` (NHWC) + # and `shared_axes`=[1, 2], so the alpha parameter is per channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"]) + out = prelu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + + input_shape = parameters["input_shape"] + input_values = create_tensor_data( + np.float32, input_shape, min_value=-10, max_value=10) + shared_axes = parameters["shared_axes"] + + alpha_shape = [] + for dim in range(1, len(input_shape)): + alpha_shape.append(1 if dim in shared_axes else input_shape[dim]) + + alpha_values = create_tensor_data(np.float32, alpha_shape) + + # There should be only 1 trainable variable tensor. + variables = tf.all_variables() + assert len(variables) == 1 + sess.run(variables[0].assign(alpha_values)) + + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + # This function tests various TensorFLow functions that generates Const op, # including `tf.ones`, `tf.zeros` and random functions. def make_constant_tests(zip_path): @@ -723,6 +772,11 @@ def make_binary_op_tests(zip_path, binary_operator): "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[3]], "activation": [True] + }, { + "dtype": [tf.float32], + "input_shape_1": [[]], + "input_shape_2": [[]], + "activation": [False] }] def build_graph(parameters): @@ -772,7 +826,7 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32, tf.int32, tf.int64], "input_shape": [[3, 2, 4]], "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], @@ -782,13 +836,19 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32], "input_shape": [[1, 8, 8, 3]], "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], "const_axis": [True, False], "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]], + "axis": [None], + "const_axis": [True], + "keepdims": [True, False], }] def build_graph(parameters): @@ -806,7 +866,7 @@ def make_reduce_tests(reduce_op): if isinstance(parameters["axis"], list): shape = [len(parameters["axis"])] else: - shape = [0] # shape for None or integers. + shape = [] # shape for None or integers. axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) input_tensors = [input_tensor, axis] @@ -817,10 +877,11 @@ def make_reduce_tests(reduce_op): def build_inputs(parameters, sess, inputs, outputs): values = [ create_tensor_data(parameters["input_dtype"], - parameters["input_shape"])] + parameters["input_shape"], + min_value=-10, + max_value=10)] if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) + values.append(np.array(parameters["axis"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -830,22 +891,30 @@ def make_reduce_tests(reduce_op): def make_mean_tests(zip_path): """Make a set of tests to do mean.""" - return make_reduce_tests(tf.reduce_mean)(zip_path) def make_sum_tests(zip_path): """Make a set of tests to do sum.""" - return make_reduce_tests(tf.reduce_sum)(zip_path) +def make_reduce_prod_tests(zip_path): + """Make a set of tests to do prod.""" + return make_reduce_tests(tf.reduce_prod)(zip_path) + + +def make_reduce_max_tests(zip_path): + """Make a set of tests to do max.""" + return make_reduce_tests(tf.reduce_max)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -904,8 +973,8 @@ def make_maximum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -939,8 +1008,8 @@ def make_minimum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -1325,6 +1394,12 @@ def make_fully_connected_tests(zip_path): "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + }, { + "shape1": [[40, 37]], + "shape2": [[40, 37]], + "transpose_a": [False], + "transpose_b": [True], + "constant_filter": [True, False], }] def build_graph(parameters): @@ -1532,19 +1607,34 @@ def make_reshape_tests(zip_path): "dtype": [tf.float32, tf.int32], "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + "constant_shape": [True, False], }] def build_graph(parameters): input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.reshape(input_tensor, shape=parameters["output_shape"]) - return [input_tensor], [out] + + # Get shape as either a placeholder or constants. + if parameters["constant_shape"]: + output_shape = parameters["output_shape"] + input_tensors = [input_tensor] + else: + # The shape of the shape tensor. + shape_tensor_shape = [len(parameters["output_shape"])] + output_shape = tf.placeholder( + dtype=tf.int32, name="output_shape", shape=shape_tensor_shape) + input_tensors = [input_tensor, output_shape] + out = tf.reshape(input_tensor, shape=output_shape) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_shape"]: + values.append(np.array(parameters["output_shape"])) + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -2169,14 +2259,15 @@ def make_topk_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_arg_max_tests(zip_path): +def make_arg_min_max_tests(zip_path): """Make a set of tests to do arg_max.""" test_parameters = [{ "input_dtype": [tf.float32, tf.int32], - "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], + "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], "output_type": [tf.int32, tf.int64], "axis_is_last_dim": [True, False], + "is_arg_max": [True], }] def build_graph(parameters): @@ -2189,7 +2280,10 @@ def make_arg_max_tests(zip_path): axis = len(parameters["input_shape"]) - 1 else: axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0)) - out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + if parameters["is_arg_max"]: + out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + else: + out = tf.arg_min(input_value, axis, output_type=parameters["output_type"]) return [input_value], [out] def build_inputs(parameters, sess, inputs, outputs): @@ -2206,7 +2300,8 @@ def make_equal_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + "input_shape_pair": [([], []), + ([1, 1, 1, 3], [1, 1, 1, 3]), ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), ([5, 5], [1]), ([10], [2, 4, 10])], }] @@ -2463,7 +2558,7 @@ def _make_elementwise_tests(op): """Actual function that generates examples.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], }] def build_graph(parameters): @@ -2785,6 +2880,44 @@ def make_sparse_to_dense_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_pack_tests(zip_path): + """Make a set of tests to do stack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + "additional_shape": [1, 2, 3], + }] + + def get_shape(parameters): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += parameters["additional_shape"] + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters)) + all_tensors.append(input_tensor) + out = tf.stack(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for _ in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, get_shape(parameters)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None |