Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py  185
1 file changed, 159 insertions(+), 26 deletions(-)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 1360f1a273..41ece94237 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -94,8 +94,8 @@ KNOWN_BUGS = {
r"sigmoid.*input_shape=\[\]": "67645668",
# Concat doesn't work with a single input tensor
r"concat.*num_tensors=1": "67378344",
- # Transposition in MatMul is not supported.
- r"fully_connected.*transpose_.=True": "67586970",
+ # Transposition in MatMul is not fully supported.
+ "fully_connected.*transpose_a=True": "67586970",
# Softmax graphs are too complex.
r"softmax.*dim=0": "67749831",
# BatchToSpaceND only supports 4D tensors.
@@ -678,6 +678,55 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
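+ # Axes listed in `shared_axes` share one alpha value (size 1); every
+ # other axis gets a per-element alpha matching the input dimension.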
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ # There should be only 1 trainable variable tensor.
+ variables = tf.all_variables()
+ assert len(variables) == 1
+ sess.run(variables[0].assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
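+ # Freeze the graph so the assigned alpha variable is baked in as a
+ # constant before conversion.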
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
# This function tests various TensorFlow functions that generate Const ops,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -723,6 +772,11 @@ def make_binary_op_tests(zip_path, binary_operator):
"input_shape_1": [[1, 3, 4, 3]],
"input_shape_2": [[3]],
"activation": [True]
+ }, {
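+ # Test scalar (rank-0) inputs, written as the empty shape [].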
+ "dtype": [tf.float32],
+ "input_shape_1": [[]],
+ "input_shape_2": [[]],
+ "activation": [False]
}]
def build_graph(parameters):
@@ -772,7 +826,7 @@ def make_reduce_tests(reduce_op):
"input_dtype": [tf.float32, tf.int32, tf.int64],
"input_shape": [[3, 2, 4]],
"axis": [
- None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
+ 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
[2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0],
[-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
],
@@ -782,13 +836,19 @@ def make_reduce_tests(reduce_op):
"input_dtype": [tf.float32],
"input_shape": [[1, 8, 8, 3]],
"axis": [
- None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
+ 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
[3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2,
-3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
[2, 2, 3], [-3, -3, -4], [-3, 2, 1]
],
"const_axis": [True, False],
"keepdims": [True, False],
+ }, {
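+ # axis=None reduces over all dimensions (including the scalar input
+ # shape []) and is only exercised as a constant axis.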
+ "input_dtype": [tf.float32],
+ "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]],
+ "axis": [None],
+ "const_axis": [True],
+ "keepdims": [True, False],
}]
def build_graph(parameters):
@@ -806,7 +866,7 @@ def make_reduce_tests(reduce_op):
if isinstance(parameters["axis"], list):
shape = [len(parameters["axis"])]
else:
- shape = [0] # shape for None or integers.
+ shape = [] # shape for None or integers.
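+ # A single integer axis is a scalar, so its placeholder shape is []
+ # rather than the zero-length vector [0].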
axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape)
input_tensors = [input_tensor, axis]
@@ -817,10 +877,11 @@ def make_reduce_tests(reduce_op):
def build_inputs(parameters, sess, inputs, outputs):
values = [
create_tensor_data(parameters["input_dtype"],
- parameters["input_shape"])]
+ parameters["input_shape"],
+ min_value=-10,
+ max_value=10)]
if not parameters["const_axis"]:
- if parameters["axis"]:
- values.append(np.array(parameters["axis"]))
+ values.append(np.array(parameters["axis"]))
return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -830,22 +891,30 @@ def make_reduce_tests(reduce_op):
def make_mean_tests(zip_path):
"""Make a set of tests to do mean."""
-
return make_reduce_tests(tf.reduce_mean)(zip_path)
def make_sum_tests(zip_path):
"""Make a set of tests to do sum."""
-
return make_reduce_tests(tf.reduce_sum)(zip_path)
+def make_reduce_prod_tests(zip_path):
+ """Make a set of tests to do prod."""
+ return make_reduce_tests(tf.reduce_prod)(zip_path)
+
+
+def make_reduce_max_tests(zip_path):
+ """Make a set of tests to do max."""
+ return make_reduce_tests(tf.reduce_max)(zip_path)
+
+
def make_exp_tests(zip_path):
"""Make a set of tests to do exp."""
test_parameters = [{
"input_dtype": [tf.float32],
- "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+ "input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
}]
def build_graph(parameters):
@@ -904,8 +973,8 @@ def make_maximum_tests(zip_path):
test_parameters = [{
"input_dtype": [tf.float32],
- "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
- "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+ "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+ "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
}]
def build_graph(parameters):
@@ -939,8 +1008,8 @@ def make_minimum_tests(zip_path):
test_parameters = [{
"input_dtype": [tf.float32],
- "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
- "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+ "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+ "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
}]
def build_graph(parameters):
@@ -1325,6 +1394,12 @@ def make_fully_connected_tests(zip_path):
"transpose_a": [False],
"transpose_b": [False],
"constant_filter": [True, False],
+ }, {
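+ # Exercise the transpose_b path, where MatMul transposes the weights.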
+ "shape1": [[40, 37]],
+ "shape2": [[40, 37]],
+ "transpose_a": [False],
+ "transpose_b": [True],
+ "constant_filter": [True, False],
}]
def build_graph(parameters):
@@ -1532,19 +1607,34 @@ def make_reshape_tests(zip_path):
"dtype": [tf.float32, tf.int32],
"input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
"output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
+ "constant_shape": [True, False],
}]
def build_graph(parameters):
input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input",
shape=parameters["input_shape"])
- out = tf.reshape(input_tensor, shape=parameters["output_shape"])
- return [input_tensor], [out]
+
+ # Supply the new shape either as a constant or as a placeholder.
+ if parameters["constant_shape"]:
+ output_shape = parameters["output_shape"]
+ input_tensors = [input_tensor]
+ else:
+ # The shape of the shape tensor.
+ shape_tensor_shape = [len(parameters["output_shape"])]
+ output_shape = tf.placeholder(
+ dtype=tf.int32, name="output_shape", shape=shape_tensor_shape)
+ input_tensors = [input_tensor, output_shape]
+ out = tf.reshape(input_tensor, shape=output_shape)
+ return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_values = create_tensor_data(parameters["dtype"],
- parameters["input_shape"])
- return [input_values], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_values])))
+ values = [
+ create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ ]
+ if not parameters["constant_shape"]:
+ values.append(np.array(parameters["output_shape"]))
+
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -2169,14 +2259,15 @@ def make_topk_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_arg_max_tests(zip_path):
+def make_arg_min_max_tests(zip_path):
"""Make a set of tests to do arg_max."""
test_parameters = [{
"input_dtype": [tf.float32, tf.int32],
- "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
+ "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"axis_is_last_dim": [True, False],
+ "is_arg_max": [True],
}]
def build_graph(parameters):
@@ -2189,7 +2280,10 @@ def make_arg_max_tests(zip_path):
axis = len(parameters["input_shape"]) - 1
else:
axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0))
- out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ if parameters["is_arg_max"]:
+ out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ else:
+ out = tf.arg_min(input_value, axis, output_type=parameters["output_type"])
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
@@ -2206,7 +2300,8 @@ def make_equal_tests(zip_path):
test_parameters = [{
"input_dtype": [tf.float32, tf.int32, tf.int64],
- "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ "input_shape_pair": [([], []),
+ ([1, 1, 1, 3], [1, 1, 1, 3]),
([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
([5, 5], [1]), ([10], [2, 4, 10])],
}]
@@ -2463,7 +2558,7 @@ def _make_elementwise_tests(op):
"""Actual function that generates examples."""
test_parameters = [{
"input_dtype": [tf.float32],
- "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
}]
def build_graph(parameters):
@@ -2785,6 +2880,44 @@ def make_sparse_to_dense_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_pack_tests(zip_path):
+ """Make a set of tests to do stack."""
+
+ test_parameters = [{
+ "base_shape": [[3, 4, 3], [3, 4], [5]],
+ "num_tensors": [1, 2, 3, 4, 5, 6],
+ "axis": [0, 1, 2, 3],
+ "additional_shape": [1, 2, 3],
+ }]
+
+ def get_shape(parameters):
+ """Return a tweaked version of 'base_shape'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
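+ # Grow the dimension at 'axis' only when it exists in the base shape.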
+ if axis < len(shape):
+ shape[axis] += parameters["additional_shape"]
+ return shape
+
+ def build_graph(parameters):
+ all_tensors = []
+ for n in range(0, parameters["num_tensors"]):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters))
+ all_tensors.append(input_tensor)
+ out = tf.stack(all_tensors, parameters["axis"])
+ return all_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ all_values = []
+ for _ in range(0, parameters["num_tensors"]):
+ input_values = create_tensor_data(np.float32, get_shape(parameters))
+ all_values.append(input_values)
+ return all_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, all_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
# Toco binary path provided by the generate rule.
bin_path = None