aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/nth_element_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/nth_element_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/nth_element_op_test.py174
1 files changed, 174 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/nth_element_op_test.py b/tensorflow/python/kernel_tests/nth_element_op_test.py
new file mode 100644
index 0000000000..58cd46d2d5
--- /dev/null
+++ b/tensorflow/python/kernel_tests/nth_element_op_test.py
@@ -0,0 +1,174 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.platform import test
+
+
+class NthElementTest(test.TestCase):
+
+ def _validateNthElement(self, inputs, dtype, n, reverse, expected_values):
+ np_expected_values = np.array(expected_values)
+ with self.test_session(use_gpu=False) as sess:
+ inputs_op = ops.convert_to_tensor(inputs, dtype=dtype)
+ values_op = nn_ops.nth_element(inputs_op, n, reverse=reverse)
+ values = sess.run(values_op)
+
+ self.assertShapeEqual(np_expected_values, values_op)
+ self.assertAllClose(np_expected_values, values)
+
+ def testExample1(self):
+ inputs = [2.2, 4.4, 1.1, 5.5, 3.3]
+ self._validateNthElement(inputs, dtypes.float32, 1, False, 2.2)
+ self._validateNthElement(inputs, dtypes.float32, 1, True, 4.4)
+
+ def testExample2(self):
+ inputs = [[2.2, 4.4, 1.1], [5.5, 3.3, 6.6]]
+ self._validateNthElement(inputs, dtypes.float64, 2, False, [4.4, 6.6])
+ self._validateNthElement(inputs, dtypes.float64, 2, True, [1.1, 3.3])
+
+ def testExample3(self):
+ inputs = [[[2, 4, 1], [5, -3, 6]],
+ [[7, 9, -8], [9, 0, 4]]]
+ self._validateNthElement(inputs, dtypes.int32, 0, False,
+ [[1, -3], [-8, 0]])
+ self._validateNthElement(inputs, dtypes.int64, 0, True,
+ [[4, 6], [9, 9]])
+
+ def _testFloatLargeInput(self, input_shape):
+ inputs = np.random.random_sample(input_shape)
+ n = np.random.randint(input_shape[-1])
+ sort_inputs = np.sort(inputs)
+ expected_values = sort_inputs[..., n]
+ self._validateNthElement(
+ inputs, dtypes.float32, n, False, expected_values)
+ expected_values = sort_inputs[..., ::-1][..., n]
+ self._validateNthElement(
+ inputs, dtypes.float64, n, True, expected_values)
+
+ def _testIntLargeInput(self, input_shape):
+ inputs = np.random.randint(-1e3, 1e3, input_shape)
+ n = np.random.randint(input_shape[-1])
+ sort_inputs = np.sort(inputs)
+ expected_values = sort_inputs[..., n]
+ self._validateNthElement(
+ inputs, dtypes.int32, n, False, expected_values)
+ expected_values = sort_inputs[..., ::-1][..., n]
+ self._validateNthElement(
+ inputs, dtypes.int64, n, True, expected_values)
+
+ def _testLargeInput(self, input_shape):
+ self._testFloatLargeInput(input_shape)
+ self._testIntLargeInput(input_shape)
+
+ def testLargeInput(self):
+ self._testLargeInput([1])
+ self._testLargeInput([10])
+ self._testLargeInput([5, 10])
+ self._testLargeInput([50, 100])
+ self._testLargeInput([50, 10000])
+ self._testLargeInput([50, 10, 100])
+ self._testLargeInput([50, 10, 10, 100])
+
+ def _testEnumerateN(self, input_shape):
+ inputs = np.random.random_sample(input_shape)
+ sort_inputs = np.sort(inputs)
+ for n in range(input_shape[-1]):
+ expected_values = sort_inputs[..., n]
+ self._validateNthElement(
+ inputs, dtypes.float32, n, False, expected_values)
+ expected_values = sort_inputs[..., ::-1][..., n]
+ self._validateNthElement(
+ inputs, dtypes.float64, n, True, expected_values)
+
+ def testEnumerateN(self):
+ self._testEnumerateN([1])
+ self._testEnumerateN([10])
+ self._testEnumerateN([10, 10])
+ self._testEnumerateN([10, 10, 10])
+ self._testEnumerateN([10, 10, 10, 10])
+
+ def testInvalidInput(self):
+ with self.assertRaisesRegexp(ValueError,
+ "at least rank 1 but is rank 0"):
+ nn_ops.nth_element(5, 0)
+
+ def testInvalidInputAtEval(self):
+ with self.test_session(use_gpu=False):
+ v = array_ops.placeholder(dtype=dtypes.float32)
+ with self.assertRaisesOpError("Input must be >= 1-D"):
+ nn_ops.nth_element(v, 0).eval(feed_dict={v: 5.0})
+
+ def testInvalidN(self):
+ with self.assertRaisesRegexp(ValueError,
+ "non-negative but is -1"):
+ nn_ops.nth_element([5], -1)
+ with self.assertRaisesRegexp(ValueError,
+ "scalar but has rank 1"):
+ nn_ops.nth_element([5, 6, 3], [1])
+
+ def testInvalidNAtEval(self):
+ inputs = [[0.1, 0.2], [0.3, 0.4]]
+ with self.test_session(use_gpu=False):
+ n = array_ops.placeholder(dtypes.int32)
+ values = nn_ops.nth_element(inputs, n)
+ with self.assertRaisesOpError("Need n >= 0, got -7"):
+ values.eval(feed_dict={n: -7})
+
+ def testNTooLarge(self):
+ inputs = [[0.1, 0.2], [0.3, 0.4]]
+ with self.assertRaisesRegexp(ValueError,
+ "must have last dimension > n = 2"):
+ nn_ops.nth_element(inputs, 2)
+
+ def testNTooLargeAtEval(self):
+ inputs = [[0.1, 0.2], [0.3, 0.4]]
+ with self.test_session(use_gpu=False):
+ n = array_ops.placeholder(dtypes.int32)
+ values = nn_ops.nth_element(inputs, n)
+ with self.assertRaisesOpError(r"Input must have at least n\+1 columns"):
+ values.eval(feed_dict={n: 2})
+
+ def testGradients(self):
+ with self.test_session(use_gpu=False) as sess:
+ inputs = array_ops.placeholder(dtypes.int32, shape=[3, 5])
+ values = nn_ops.nth_element(inputs, 3)
+ grad = sess.run(
+ gradients_impl.gradients(
+ values, inputs, grad_ys=[[-1., 2., 5.]]),
+ feed_dict={inputs: [[2, -1, 1000, 3, 1000],
+ [1, 5, 2, 4, 3],
+ [2, 2, 2, 2, 2],
+ ]})
+ self.assertAllClose(grad[0], [[0, 0, -0.5, 0, -0.5],
+ [0, 0, 0, 2, 0],
+ [1, 1, 1, 1, 1],
+ ])
+
+
+
+if __name__ == "__main__":
+ test.main()