From 5a43e01ef0f8cb86d836a4d1c08a246630e26f8c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Oct 2018 09:29:00 -0700 Subject: Update XlaSort to match the underlying HLO. PiperOrigin-RevId: 215917470 --- tensorflow/compiler/tests/sort_ops_test.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'tensorflow/compiler/tests') diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index dbf4beb693..57f0ab7a9e 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,13 +48,29 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + def testKeyValueSort(self): + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for key_type in supported_types.intersection(self.numeric_types): + for value_type in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) -- cgit v1.2.3