diff options
author | 2018-07-23 19:08:40 -0700 | |
---|---|---|
committer | 2018-07-23 19:12:07 -0700 | |
commit | df7344f1933d932f03f472402068ff1883f0c011 (patch) | |
tree | 56cd6abec5b5076195b86909eddd47696b4d3612 /tensorflow/contrib/lite/testing/generate_examples.py | |
parent | efe370fcb367efd069c8166120858492dffa9a33 (diff) |
Implementation of stack.
PiperOrigin-RevId: 205763219
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index b3ccc65e85..41ece94237 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2880,6 +2880,44 @@ def make_sparse_to_dense_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_pack_tests(zip_path): + """Make a set of tests to do stack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + "additional_shape": [1, 2, 3], + }] + + def get_shape(parameters): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += parameters["additional_shape"] + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters)) + all_tensors.append(input_tensor) + out = tf.stack(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for _ in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, get_shape(parameters)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None |