aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/in_topk_op_test.py
blob: d2a51788c4e4bc75501b011772f0601aefcc9ad3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Tests for PrecisionOp."""
import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class InTopKTest(tf.test.TestCase):
  """Checks tf.nn.in_top_k against hand-computed expectations."""

  def _validateInTopK(self, predictions, target, k, expected):
    """Runs tf.nn.in_top_k and compares its result with `expected`.

    Args:
      predictions: nested list of scores; each row holds one score per class.
      target: list with one class index per row of `predictions`.
      k: how many of the highest-scoring classes count as "top".
      expected: list of bools, one per row, giving the expected op output.
    """
    np_ans = np.array(expected)
    with self.test_session():
      precision = tf.nn.in_top_k(predictions, target, k)
      out = precision.eval()
      # The expected values are booleans, so compare exactly instead of
      # using a floating-point tolerance check.
      self.assertAllEqual(np_ans, out)
      self.assertShapeEqual(np_ans, precision)

  def testInTop1(self):
    """k=1: only the single highest-scoring class is a hit."""
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
    # Row 0: class 3 (0.4) is the max -> True.
    # Row 1: class 1 (0.2) is not the max -> False.
    target = [3, 1]
    self._validateInTopK(predictions, target, 1, [True, False])

  def testInTop2(self):
    """k=2: a target is a hit if it is among the two highest scores."""
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
    # Row 0: class 0 (0.1) is the lowest score -> False.
    # Row 1: class 2 (0.3) is the second-highest score -> True.
    target = [0, 2]
    self._validateInTopK(predictions, target, 2, [False, True])

  def testInTop2Tie(self):
    """k=2 with a tie at the k-th position."""
    # Class 2 and 3 tie for 2nd, so both are considered in top 2.
    predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
    target = [2, 3]
    self._validateInTopK(predictions, target, 2, [True, True])


if __name__ == "__main__":
  # When run as a script, hand control to TensorFlow's test runner, which
  # discovers and executes the test cases defined in this module.
  tf.test.main()