# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for sorting operators.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test class XlaSortOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) for arg in args ] feeds = {placeholders[i]: args[i] for i in range(0, len(args))} output = op(*placeholders) if isinstance(output, ops.Tensor): output = [output] results = session.run(output, feeds) for result, v in zip(results, expected): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) def testKeyValueSort(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for key_type in supported_types.intersection(self.numeric_types): for value_type in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=key_type) np.random.shuffle(x) y = (-x).astype(value_type) self._assertOpOutputMatchesExpected( xla.key_value_sort, [x, y], expected=[ np.arange(101, dtype=key_type), -np.arange(101, dtype=value_type) ]) def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): # Use small input size for bfloat16. Otherwise, we'll get duplicate values # after conversion to bfloat16, so the possible resulting index array is # no longer unique. if dtype == dtypes.bfloat16.as_numpy_dtype: array_size = 20 k_options = [0, 1, 2, 10, 20] else: array_size = 200 * 1000 k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] for x in [np.arange(array_size)]: np.random.shuffle(x) for k in k_options: indices = x.argsort()[::-1][:k] def topk(v, k=k): return nn_ops.top_k(v, k=k, sorted=True) self._assertOpOutputMatchesExpected( topk, [x.astype(dtype)], expected=[x[indices].astype(dtype), indices]) def testTopK2D(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): # Use small input size for bfloat16. Otherwise, we'll get duplicate values # after conversion to bfloat16, so the possible resulting index array is # no longer unique. if dtype == dtypes.bfloat16.as_numpy_dtype: array_size = 10 k_options = [0, 1, 2, 10] else: array_size = 200 * 1000 k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] batch = 16 for x in [np.arange(batch * array_size)]: np.random.shuffle(x) x = np.reshape(x, [batch, array_size]) for k in k_options: indices = x.argsort(axis=1)[::, -1:-k - 1:-1] expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] def topk(v, k=k): return nn_ops.top_k(v, k=k, sorted=True) self._assertOpOutputMatchesExpected( topk, [x.astype(dtype)], expected=[expected.astype(dtype), indices]) def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: return with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=4) results = sess.run( topk, {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: return with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=6) results = sess.run(topk, { p: np.array( [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) }) self.assertAllEqual( np.array( [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], dtype=bfloat16), results[0]) self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) if __name__ == "__main__": test.main()