diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/topk_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/topk_op_test.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py new file mode 100644 index 0000000000..497dc9ac1e --- /dev/null +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -0,0 +1,52 @@ +"""Tests for TopK op.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class TopKTest(tf.test.TestCase): + + def _validateTopK(self, inputs, k, expected_values, expected_indices): + np_values = np.array(expected_values) + np_indices = np.array(expected_indices) + with self.test_session(): + values_op, indices_op = tf.nn.top_k(inputs, k) + values = values_op.eval() + indices = indices_op.eval() + self.assertAllClose(np_values, values) + self.assertAllEqual(np_indices, indices) + self.assertShapeEqual(np_values, values_op) + self.assertShapeEqual(np_indices, indices_op) + + def testTop1(self): + inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]] + self._validateTopK(inputs, 1, + [[0.4], [0.3]], + [[3], [1]]) + + def testTop2(self): + inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]] + self._validateTopK(inputs, 2, + [[0.4, 0.3], [0.3, 0.3]], + [[3, 1], [1, 2]]) + + def testTopAll(self): + inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]] + self._validateTopK(inputs, 4, + [[0.4, 0.3, 0.2, 0.1], [0.3, 0.3, 0.2, 0.1]], + [[3, 1, 2, 0], [1, 2, 3, 0]]) + + def testKNegative(self): + inputs = [[0.1, 0.2], [0.3, 0.4]] + with self.assertRaisesRegexp(ValueError, "less than minimum 1"): + tf.nn.top_k(inputs, -1) + + def testKTooLarge(self): + inputs = [[0.1, 0.2], [0.3, 0.4]] + with self.assertRaisesRegexp(ValueError, "input must have at least k"): + tf.nn.top_k(inputs, 4) + + +if __name__ == "__main__": + tf.test.main() |