aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/sort_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-20 14:56:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 14:58:55 -0700
commit2cd247d20422a41c33e0f4be265eba2df537ed3b (patch)
tree1f63e3bc635d1c87c5a2b3a8b7cc62d627dde756 /tensorflow/compiler/tests/sort_ops_test.py
parent164099ee4688432d614c754b1e01d56715811062 (diff)
Handle positive and negative infinity in TopKV2.
TopKV2 hides iota in the low bits of the input after converting from bf16 to f32. This usually works, but for positive and negative infinity or'ing in iota produces NANs. To handle positive and negative infinity, treat bf16 as integers in sign-magnitude format. Convert to two's complement. Sort in two's complement and convert back. Add an exhaustive unit test for bfloat16 to float conversion. PiperOrigin-RevId: 201421784
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py29
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 370085c1e2..8ae579abda 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -81,7 +81,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # Requires Sort HLO, which is not implemented on CPU or GPU.
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
@@ -99,7 +99,32 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(set([0, 2, 3, 6]), set(results[1]))
+ self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+
+ def testTopKInfinities(self):
+ """Tests that positive and negative infinity sort correctly."""
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ # Only bfloat16 is implemented.
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype
+ if bfloat16 not in self.numeric_types:
+ return
+
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtypes.bfloat16)
+ with self.test_scope():
+ topk = nn_ops.top_k(p, k=6)
+ results = sess.run(topk, {
+ p: np.array(
+ [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
+ })
+ self.assertAllEqual(
+ np.array(
+ [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
+ dtype=bfloat16), results[0])
+ self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
if __name__ == "__main__":