aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/sort_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index dbf4beb693..57f0ab7a9e 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,13 +48,29 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
+ def testKeyValueSort(self):
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for key_type in supported_types.intersection(self.numeric_types):
+ for value_type in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=key_type)
+ np.random.shuffle(x)
+ y = (-x).astype(value_type)
+ self._assertOpOutputMatchesExpected(
+ xla.key_value_sort, [x, y],
+ expected=[
+ np.arange(101, dtype=key_type),
+ -np.arange(101, dtype=value_type)
+ ])
+
def testTopK(self):
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])