 tensorflow/compiler/tests/sort_ops_test.py    | 29
 tensorflow/compiler/tf2xla/kernels/topk_op.cc | 99
 tensorflow/compiler/xla/tests/convert_test.cc | 21
 3 files changed, 121 insertions(+), 28 deletions(-)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 370085c1e2..8ae579abda 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -81,7 +81,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # Requires Sort HLO, which is not implemented on CPU or GPU.
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
@@ -99,7 +99,32 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(set([0, 2, 3, 6]), set(results[1]))
+ self.assertEqual([3, 0, 1, 2], list(results[1]))
+
+ def testTopKInfinities(self):
+ """Tests that positive and negative infinity sort correctly."""
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ # Only bfloat16 is implemented.
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype
+ if bfloat16 not in self.numeric_types:
+ return
+
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtypes.bfloat16)
+ with self.test_scope():
+ topk = nn_ops.top_k(p, k=6)
+ results = sess.run(topk, {
+ p: np.array(
+ [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
+ })
+ self.assertAllEqual(
+ np.array(
+ [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
+ dtype=bfloat16), results[0])
+ self.assertEqual([2, 1, 0, 4, 5, 3], list(results[1]))
if __name__ == "__main__":
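
Aside: the expected index lists in the zeros and infinities tests above can be reproduced with a plain NumPy reference. A stable ascending argsort of the negated input mirrors the kernel's "negate, then sort, index in the low bits" strategy, including its tie-breaking for equal keys such as +0.0 and -0.0. A minimal sketch, illustrative only (reference_top_k is a hypothetical helper, not part of this change):

    import numpy as np

    def reference_top_k(x, k):
      # Negate so an ascending sort yields descending order of the inputs;
      # a stable sort keeps equal keys (e.g. +0.0 and -0.0) in input order,
      # mirroring the kernel's index-in-low-bits tie-breaking.
      x = np.asarray(x)
      idx = np.argsort(-x, kind="stable")[:k]
      return x[idx], idx

    # Reproduces the expected indices from the tests above:
    assert list(reference_top_k(
        np.array([0., -0., 0., 3., -0., -4., 0., -0.]), 4)[1]) == [3, 0, 1, 2]
    assert list(reference_top_k(
        np.array([1., 2., np.inf, -np.inf, -1., -2.]), 6)[1]) == [2, 1, 0, 4, 5, 3]
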
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 703e13e089..cbe3c8aaff 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -61,42 +61,89 @@ class TopKOp : public XlaOpKernel {
if (input_shape.dim_size(0) < k) {
k = input_shape.dim_size(0);
}
- const xla::XlaOp input = context->Input(0);
- xla::XlaOp iota;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota));
+ const xla::XlaOp input_bf16 = context->Input(0);
+ xla::XlaOp iota_s32;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32));
// TODO(b/73891930): add a key-value sort to HLO, rather than using
// bit-packing tricks here.
- // TODO(b/73891930): this implementation will convert Infs to NaNs. A
- // key-value sort would avoid this; for now, it is no worse than, say, the
- // CPU backend in fast-math mode.
+
+ xla::XlaOp zero = b->ConstantR0<int32>(0);
+
+ // max can either be 0x7FFFFFFF or 0x80000000. Neither choice is totally
+ // ideal. The implications of the choice are:
+ //
+ // 0x7FFFFFFF
+ // 1. +0.0 > -0.0
+ // 2. The elements of the inputs and outputs are bitwise identical.
+ // 3. The sort is unstable since a later +0.0 will appear before an earlier
+ // -0.0.
+ //
+ // 0x80000000
+ // 1. +0.0 == -0.0
+ // 2. -0.0 and +0.0 pack identically, so the sign of zero is not
+ // preserved in the output.
+ // 3. The sort is stable.
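+ // This implementation picks 0x80000000, trading the sign of zero for a
+ // stable sort.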
+ xla::XlaOp max = b->ConstantR0<int32>(0x80000000);
+ xla::XlaOp index_mask = b->ConstantR0<int32>(0x0000FFFF);
+ xla::XlaOp value_mask = b->ConstantR0<int32>(0xFFFF0000);
+
+ // Convert from bf16 to f32. The lower 16-bits are zero due to the
+ // definition of bf16.
+ xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32);
+
+ // Negate the input so that an ascending sort yields the original values
+ // in descending order. The lower 16-bits stay zero, because negating a
+ // float just flips the sign bit.
+ xla::XlaOp negative_input_f32 = b->Neg(input_f32);
+
+ // Convert to a sign magnitude integer. The lower 16-bits are zero, since
+ // bitcast convert doesn't change any bits.
+ xla::XlaOp negative_input_sm32 =
+ b->BitcastConvertType(negative_input_f32, xla::S32);
+
+ // Convert from sign magnitude integer to two's complement integer. The
+ // lower 16-bits are zero on both sides of the select. On the false side,
+ // the value is unchanged, and on the true side, the lower 16-bits of max
+ // are all zero, so the lower 16-bits of the result of the subtraction will
+ // also be zero.
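+ // For example, an input of 1.0f is negated to -1.0f (bits 0xBF800000);
+ // as a sign-magnitude int32 that is negative, so it maps to the two's
+ // complement value -0x3F800000 (bits 0xC0800000), whose lower 16-bits
+ // are still zero.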
+ xla::XlaOp negative_input_s32 =
+ b->Select(b->Lt(negative_input_sm32, zero),
+ b->Sub(max, negative_input_sm32), negative_input_sm32);
+
+ // In order for the Or with iota_s32 to work properly, the lower 16-bits
+ // of negative_input_s32 must be zero.
// Pack elements as:
// * upper 16 bits are the value
// * lower 16 bits are the index.
- xla::XlaOp packed = b->BitcastConvertType(
- b->Or(b->BitcastConvertType(b->ConvertElementType(input, xla::F32),
- xla::S32),
- iota),
- xla::F32);
+ xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32);
// TODO(phawkins): use a more efficient algorithm that does not require a
// full sort.
- xla::XlaOp sorted = b->Slice(b->Rev(b->Sort(packed), {0}),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
-
- // Unpack the value/index
- xla::XlaOp x = b->BitcastConvertType(sorted, xla::S32);
- xla::XlaOp indices = b->And(x, b->ConstantR0<int32>(0x0000FFFF));
- xla::XlaOp values = b->ConvertElementType(
- b->BitcastConvertType(b->And(x, b->ConstantR0<int32>(0xFFFF0000)),
- xla::F32),
- xla::BF16);
-
- context->SetOutput(0, values);
- context->SetOutput(1, indices);
+ xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
+
+ // Unpack the value/index.
+ xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask);
+ xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask);
+
+ // Convert from two's complement integer to sign magnitude integer.
+ xla::XlaOp negative_values_sm32 =
+ b->Select(b->Lt(negative_values_s32, zero),
+ b->Sub(max, negative_values_s32), negative_values_s32);
+
+ xla::XlaOp negative_values_f32 =
+ b->BitcastConvertType(negative_values_sm32, xla::F32);
+
+ // Negate the values to get back the original inputs.
+ xla::XlaOp values_f32 = b->Neg(negative_values_f32);
+
+ // Convert from f32 to bf16.
+ xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16);
+
+ context->SetOutput(0, values_bf16);
+ context->SetOutput(1, indices_s32);
}
private:
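
Taken together, the new op sequence above implements top-k as: widen bf16 to f32, negate, bitcast to a sign-magnitude integer, map to two's complement, OR the index into the spare low 16 bits, sort ascending, slice off k, and then invert each step. The NumPy sketch below mirrors that pipeline for illustration only; topk_bf16_sketch is a hypothetical name, it emulates bf16 inputs by masking the low 16 bits of f32 values, and it assumes fewer than 2^16 elements, as the kernel does:

    import numpy as np

    def topk_bf16_sketch(values, k):
      # Emulate bf16 inputs: the f32 image of a bf16 has its low 16 bits zero.
      v = np.asarray(values, dtype=np.float32)
      v = (v.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)
      neg_f32 = -v                    # ascending sort => descending values
      sm = neg_f32.view(np.int32)     # bitcast: sign-magnitude integer
      int_min = np.int32(-(2**31))    # the 0x80000000 "max" from the kernel
      # Sign-magnitude -> two's complement. np.where evaluates both branches;
      # the unused branch may wrap, which is harmless here.
      tc = np.where(sm < 0, int_min - sm, sm)
      # The low 16 bits of tc are zero, so the index can be OR'ed in.
      packed = tc | np.arange(v.size, dtype=np.int32)
      top = np.sort(packed)[:k]
      # Unpack, then undo the conversion, the bitcast, and the negation.
      indices = top & np.int32(0x0000FFFF)
      neg_tc = top & np.int32(-0x10000)   # the 0xFFFF0000 value mask
      neg_sm = np.where(neg_tc < 0, int_min - neg_tc, neg_tc)
      return -neg_sm.view(np.float32), indices

    vals, idx = topk_bf16_sketch([0., -0., 0., 3., -0., -4., 0., -0.], 4)
    # vals -> [3., 0., 0., 0.] (up to the sign of zero), idx -> [3, 0, 1, 2],
    # matching testTopKZeros above.
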
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 722d882471..3a885b4389 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -461,5 +461,26 @@ XLA_TEST_F(ConvertTest, ConvertS64U64) {
ComputeAndCompareR1<uint64>(&builder, unsigned_x, {});
}
+XLA_TEST_F(ConvertTest, ConvertBF16F32) {
+ XlaBuilder builder(TestName());
+
+ std::vector<bfloat16> all_bfloats(1 << 16);
+ for (int i = 0; i < all_bfloats.size(); ++i) {
+ all_bfloats[i].value = i;
+ }
+
+ std::vector<uint32> expected(all_bfloats.size());
+ for (int i = 0; i < expected.size(); ++i) {
+ expected[i] = (1U << 16) * i;
+ }
+
+ // Exhaustively test all bf16 to f32 conversions.
+ xla::XlaOp all_bfloats_bf16 = builder.ConstantR1<bfloat16>(all_bfloats);
+ xla::XlaOp all_bfloats_f32 =
+ builder.ConvertElementType(all_bfloats_bf16, F32);
+ xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32);
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
} // namespace
} // namespace xla
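
The expected values in this test rely on the fact that widening bf16 to f32 just appends sixteen zero bits: the f32 bit pattern is the bf16 pattern shifted left by 16, which is why expected[i] is (1U << 16) * i. A quick plain-Python illustration of that identity (f32_from_bf16_bits is a hypothetical helper, not part of this change):

    import struct

    def f32_from_bf16_bits(bits16):
      # The f32 whose bit pattern is `bits16 << 16`; bf16 truncates the f32
      # mantissa, so widening just restores the sixteen zeroed-out low bits.
      return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

    assert f32_from_bf16_bits(0x3F80) == 1.0    # bf16 1.0
    assert f32_from_bf16_bits(0xC040) == -3.0   # bf16 -3.0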