diff options
author | 2018-06-20 14:56:00 -0700 | |
---|---|---|
committer | 2018-06-20 14:58:55 -0700 | |
commit | 2cd247d20422a41c33e0f4be265eba2df537ed3b (patch) | |
tree | 1f63e3bc635d1c87c5a2b3a8b7cc62d627dde756 /tensorflow/compiler/tests/sort_ops_test.py | |
parent | 164099ee4688432d614c754b1e01d56715811062 (diff) |
Handle positive and negative infinity in TopKV2.
TopKV2 hides iota in the low bits of the input after converting from bf16 to f32. This usually works, but for positive and negative infinity, OR-ing in iota produces NaNs.
To handle positive and negative infinity, treat bf16 as integers in
sign-magnitude format. Convert to two's complement. Sort in two's complement and
convert back.
Add an exhaustive unit test for bfloat16 to float conversion.
PiperOrigin-RevId: 201421784
Diffstat (limited to 'tensorflow/compiler/tests/sort_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/sort_ops_test.py | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 370085c1e2..8ae579abda 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -81,7 +81,7 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # Requires Sort HLO, which is not implemented on CPU or GPU. + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. if self.device in ["XLA_CPU", "XLA_GPU"]: return @@ -99,7 +99,32 @@ class XlaSortOpTest(xla_test.XLATestCase): {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(set([0, 2, 3, 6]), set(results[1])) + self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) if __name__ == "__main__": |