diff options
author | 2018-08-08 11:22:41 -0700 | |
---|---|---|
committer | 2018-08-08 11:26:51 -0700 | |
commit | a1d3ebbe6f40768ea5ccf6beef9e905bce207b42 (patch) | |
tree | 21f4ba44273ba4b2f6caf10f56c20012f11fbd6e /tensorflow/contrib/lite/testing | |
parent | e041ae7dcc8916a4f5b358c45a68a6ec731e86e2 (diff) |
Fixing for TopKV2 shape propagation in Toco.
PiperOrigin-RevId: 207917321
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 73c05e21b6..1dfd086f5a 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1670,7 +1670,7 @@ def make_shape_tests(zip_path): }] def build_graph(parameters): - """Build the topk op testing graph.""" + """Build the shape op testing graph.""" # Note that we intentionally leave out the shape from the input placeholder # to prevent the Shape operation from being optimized out during conversion. input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") @@ -2318,6 +2318,7 @@ def make_topk_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[10], [5, 20]], + "input_k": [None, 1, 3], }] def build_graph(parameters): @@ -2326,15 +2327,23 @@ def make_topk_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - k = tf.constant(3, name="k") + if parameters["input_k"] is not None: + k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[]) + else: + k = tf.constant(3, name="k") out = tf.nn.top_k(input_value, k) - return [input_value], [out[1]] + return [input_value, k], [out[1]] def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_value]))) + if parameters["input_k"] is not None: + k = np.array(parameters["input_k"], dtype=np.int32) + return [input_value, k], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, k]))) + else: + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) |