aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/sort_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py57
1 files changed, 49 insertions, 8 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 8ae579abda..7ff01be3cb 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -64,20 +64,61 @@ class XlaSortOpTest(xla_test.XLATestCase):
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
- # Only bfloat16 is implemented.
- bfloat16 = dtypes.bfloat16.as_numpy_dtype
- if bfloat16 in self.numeric_types:
- for x in [np.arange(20)]:
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 20
+ k_options = [0, 1, 2, 10, 20]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ for x in [np.arange(array_size)]:
np.random.shuffle(x)
- for k in [0, 1, 2, 10, 20]:
+ for k in k_options:
indices = x.argsort()[::-1][:k]
def topk(v, k=k):
return nn_ops.top_k(v, k=k, sorted=True)
self._assertOpOutputMatchesExpected(
- topk, [x.astype(bfloat16)],
- expected=[x[indices].astype(bfloat16), indices])
+ topk, [x.astype(dtype)],
+ expected=[x[indices].astype(dtype), indices])
+
+ def testTopK2D(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 10
+ k_options = [0, 1, 2, 10]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ batch = 16
+ for x in [np.arange(batch * array_size)]:
+ np.random.shuffle(x)
+ x = np.reshape(x, [batch, array_size])
+ for k in k_options:
+ indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+ expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+ def topk(v, k=k):
+ return nn_ops.top_k(v, k=k, sorted=True)
+
+ self._assertOpOutputMatchesExpected(
+ topk, [x.astype(dtype)],
+ expected=[expected.astype(dtype), indices])
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
@@ -99,7 +140,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+ self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""