aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/sort_ops_test.py
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-06-18 09:16:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 09:23:18 -0700
commit8ecf506fb8464dd273ce59f512f5e20d37dd5cfd (patch)
tree86463ca85a8382743d040000e7135fa135c16f32 /tensorflow/compiler/tests/sort_ops_test.py
parent3db3e50bb0c02d6f0c7284d50bc31e97ebfc96e5 (diff)
[TF:XLA] Add a XlaSort operator that directly wraps the Sort HLO.
Merge XLA-specific operator registrations into a single file rather than having many tiny files. In passing, register a fill function for bfloat16 numpy type; needed for the np.arange() call in the sort unit test. PiperOrigin-RevId: 201005718
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
new file mode 100644
index 0000000000..5ff40edaa5
--- /dev/null
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XlaSort."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class XlaSortOpTest(xla_test.XLATestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ result = session.run(output, feeds)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testSort(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+ supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=dtype)
+ np.random.shuffle(x)
+ self._assertOpOutputMatchesExpected(
+ xla.sort, [x], expected=np.arange(101, dtype=dtype))
+
+
+if __name__ == "__main__":
+ test.main()