"""Tests for PrecisionOp.

Every case in this module calls ``tf.nn.in_top_k`` on a small batch of
prediction rows and compares the evaluated result with a hand-written list
of expected booleans.
"""
import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class InTopKTest(tf.test.TestCase):
    """Checks ``tf.nn.in_top_k`` against fixed prediction/target batches."""

    def _validateInTopK(self, predictions, target, k, expected):
        """Run ``tf.nn.in_top_k`` and verify both its values and its shape.

        Args:
          predictions: nested list of per-class scores, one row per example.
          target: list with one class index per row of `predictions`.
          k: the `k` argument forwarded to ``tf.nn.in_top_k``.
          expected: list of booleans the op is expected to produce.
        """
        expected_array = np.array(expected)
        with self.test_session():
            in_top_k_op = tf.nn.in_top_k(predictions, target, k)
            # Evaluate inside the session context, then compare values
            # first and the static shape second.
            self.assertAllClose(expected_array, in_top_k_op.eval())
            self.assertShapeEqual(expected_array, in_top_k_op)

    def testInTop1(self):
        """k=1: row 0 targets its largest score, row 1 does not."""
        scores = [
            [0.1, 0.3, 0.2, 0.4],
            [0.1, 0.2, 0.3, 0.4],
        ]
        labels = [3, 1]
        self._validateInTopK(scores, labels, 1, [True, False])

    def testInTop2(self):
        """k=2: row 0 targets its smallest score, row 1 its second largest."""
        scores = [
            [0.1, 0.3, 0.2, 0.4],
            [0.1, 0.2, 0.3, 0.4],
        ]
        labels = [0, 2]
        self._validateInTopK(scores, labels, 2, [False, True])

    def testInTop2Tie(self):
        """k=2 with a tie for second place between classes 2 and 3."""
        # Classes 2 and 3 share the second-highest score (0.2), so a target
        # of either one is expected to count as being in the top 2.
        scores = [
            [0.1, 0.3, 0.2, 0.2],
            [0.1, 0.3, 0.2, 0.2],
        ]
        labels = [2, 3]
        self._validateInTopK(scores, labels, 2, [True, True])


if __name__ == "__main__":
    tf.test.main()