aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/in_topk_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/in_topk_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
new file mode 100644
index 0000000000..d2a51788c4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -0,0 +1,36 @@
+"""Tests for PrecisionOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class InTopKTest(tf.test.TestCase):
+
+ def _validateInTopK(self, predictions, target, k, expected):
+ np_ans = np.array(expected)
+ with self.test_session():
+ precision = tf.nn.in_top_k(predictions, target, k)
+ out = precision.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, precision)
+
+ def testInTop1(self):
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ target = [3, 1]
+ self._validateInTopK(predictions, target, 1, [True, False])
+
+ def testInTop2(self):
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ target = [0, 2]
+ self._validateInTopK(predictions, target, 2, [False, True])
+
+ def testInTop2Tie(self):
+ # Class 2 and 3 tie for 2nd, so both are considered in top 2.
+ predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
+ target = [2, 3]
+ self._validateInTopK(predictions, target, 2, [True, True])
+
+
+if __name__ == "__main__":
+ tf.test.main()