aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/generate_examples.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-23 19:08:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 19:12:07 -0700
commitdf7344f1933d932f03f472402068ff1883f0c011 (patch)
tree56cd6abec5b5076195b86909eddd47696b4d3612 /tensorflow/contrib/lite/testing/generate_examples.py
parentefe370fcb367efd069c8166120858492dffa9a33 (diff)
Implementation of stack.
PiperOrigin-RevId: 205763219
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index b3ccc65e85..41ece94237 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2880,6 +2880,44 @@ def make_sparse_to_dense_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_pack_tests(zip_path):
+ """Make a set of tests to do stack."""
+
+ test_parameters = [{
+ "base_shape": [[3, 4, 3], [3, 4], [5]],
+ "num_tensors": [1, 2, 3, 4, 5, 6],
+ "axis": [0, 1, 2, 3],
+ "additional_shape": [1, 2, 3],
+ }]
+
+ def get_shape(parameters):
+ """Return a tweaked version of 'base_shape'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ if axis < len(shape):
+ shape[axis] += parameters["additional_shape"]
+ return shape
+
+ def build_graph(parameters):
+ all_tensors = []
+ for n in range(0, parameters["num_tensors"]):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters))
+ all_tensors.append(input_tensor)
+ out = tf.stack(all_tensors, parameters["axis"])
+ return all_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ all_values = []
+ for _ in range(0, parameters["num_tensors"]):
+ input_values = create_tensor_data(np.float32, get_shape(parameters))
+ all_values.append(input_values)
+ return all_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, all_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
# Toco binary path provided by the generate rule.
bin_path = None