diff options
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/sort_ops_test.py | 57 |
1 files changed, 49 insertions, 8 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 8ae579abda..7ff01be3cb 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -64,20 +64,61 @@ class XlaSortOpTest(xla_test.XLATestCase): if self.device in ["XLA_CPU", "XLA_GPU"]: return - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 in self.numeric_types: - for x in [np.arange(20)]: + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 20 + k_options = [0, 1, 2, 10, 20] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + for x in [np.arange(array_size)]: np.random.shuffle(x) - for k in [0, 1, 2, 10, 20]: + for k in k_options: indices = x.argsort()[::-1][:k] def topk(v, k=k): return nn_ops.top_k(v, k=k, sorted=True) self._assertOpOutputMatchesExpected( - topk, [x.astype(bfloat16)], - expected=[x[indices].astype(bfloat16), indices]) + topk, [x.astype(dtype)], + expected=[x[indices].astype(dtype), indices]) + + def testTopK2D(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 10 + k_options = [0, 1, 2, 10] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[expected.astype(dtype), indices]) def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" @@ -99,7 +140,7 @@ class XlaSortOpTest(xla_test.XLATestCase): {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" |