diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index cb5c500136..8045052452 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -862,6 +862,41 @@ def make_log_softmax_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_maximum_tests(zip_path): + """Make a set of tests to do maximum.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + }] + + def build_graph(parameters): + """Build the maximum op testing graph.""" + input_tensor_1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input_1", + shape=parameters["input_shape_1"]) + input_tensor_2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input_2", + shape=parameters["input_shape_2"]) + + out = tf.maximum(input_tensor_1, input_tensor_2) + return [input_tensor_1, input_tensor_2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape_1"]), + create_tensor_data(parameters["input_dtype"], + parameters["input_shape_2"]) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_binary_op_tests_func(binary_operator): """Return a function that does a test on a binary operator.""" return lambda zip_path: make_binary_op_tests(zip_path, binary_operator) @@ -1977,6 +2012,7 @@ def main(unused_args): "exp.zip": make_exp_tests, "log_softmax.zip": make_log_softmax_tests, "lstm.zip": make_lstm_tests, + "maximum.zip": make_maximum_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco |