aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-08 11:22:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 11:26:51 -0700
commita1d3ebbe6f40768ea5ccf6beef9e905bce207b42 (patch)
tree21f4ba44273ba4b2f6caf10f56c20012f11fbd6e /tensorflow/contrib/lite/testing
parente041ae7dcc8916a4f5b358c45a68a6ec731e86e2 (diff)
Fixing for TopKV2 shape propagation in Toco.
PiperOrigin-RevId: 207917321
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 73c05e21b6..1dfd086f5a 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1670,7 +1670,7 @@ def make_shape_tests(zip_path):
}]
def build_graph(parameters):
- """Build the topk op testing graph."""
+ """Build the shape op testing graph."""
# Note that we intentionally leave out the shape from the input placeholder
# to prevent the Shape operation from being optimized out during conversion.
input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
@@ -2318,6 +2318,7 @@ def make_topk_tests(zip_path):
test_parameters = [{
"input_dtype": [tf.float32, tf.int32],
"input_shape": [[10], [5, 20]],
+ "input_k": [None, 1, 3],
}]
def build_graph(parameters):
@@ -2326,15 +2327,23 @@ def make_topk_tests(zip_path):
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
- k = tf.constant(3, name="k")
+ if parameters["input_k"] is not None:
+ k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+ else:
+ k = tf.constant(3, name="k")
out = tf.nn.top_k(input_value, k)
- return [input_value], [out[1]]
+ return [input_value, k], [out[1]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
- return [input_value], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_value])))
+ if parameters["input_k"] is not None:
+ k = np.array(parameters["input_k"], dtype=np.int32)
+ return [input_value, k], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value, k])))
+ else:
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)