diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/in_topk_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/in_topk_op_test.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py new file mode 100644 index 0000000000..d2a51788c4 --- /dev/null +++ b/tensorflow/python/kernel_tests/in_topk_op_test.py @@ -0,0 +1,36 @@ +"""Tests for PrecisionOp.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class InTopKTest(tf.test.TestCase): + + def _validateInTopK(self, predictions, target, k, expected): + np_ans = np.array(expected) + with self.test_session(): + precision = tf.nn.in_top_k(predictions, target, k) + out = precision.eval() + self.assertAllClose(np_ans, out) + self.assertShapeEqual(np_ans, precision) + + def testInTop1(self): + predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + target = [3, 1] + self._validateInTopK(predictions, target, 1, [True, False]) + + def testInTop2(self): + predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + target = [0, 2] + self._validateInTopK(predictions, target, 2, [False, True]) + + def testInTop2Tie(self): + # Class 2 and 3 tie for 2nd, so both are considered in top 2. + predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]] + target = [2, 3] + self._validateInTopK(predictions, target, 2, [True, True]) + + +if __name__ == "__main__": + tf.test.main() |