diff options
-rw-r--r-- | tensorflow/compiler/tests/sort_ops_test.py | 29 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/topk_op.cc | 99 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/convert_test.cc | 21 |
3 files changed, 121 insertions, 28 deletions
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 370085c1e2..8ae579abda 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -81,7 +81,7 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # Requires Sort HLO, which is not implemented on CPU or GPU. + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. if self.device in ["XLA_CPU", "XLA_GPU"]: return @@ -99,7 +99,32 @@ class XlaSortOpTest(xla_test.XLATestCase): {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(set([0, 2, 3, 6]), set(results[1])) + self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. 
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 703e13e089..cbe3c8aaff 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -61,42 +61,89 @@ class TopKOp : public XlaOpKernel { if (input_shape.dim_size(0) < k) { k = input_shape.dim_size(0); } - const xla::XlaOp input = context->Input(0); - xla::XlaOp iota; - OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota)); + const xla::XlaOp input_bf16 = context->Input(0); + xla::XlaOp iota_s32; + OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32)); // TODO(b/73891930): add a key-value sort to HLO, rather than using // bit-packing tricks here. - // TODO(b/73891930): this implementation will convert Infs to NaNs. A - // key-value sort would avoid this; for now, it is no worse than, say, the - // CPU backend in fast-math mode. + + xla::XlaOp zero = b->ConstantR0<int32>(0); + + // max can either be 0x7FFFFFFF or 0x80000000. Neither choice is totally + // ideal. The implications of the choice are: + // + // 0x7FFFFFFF + // 1. +0.0 > -0.0 + // 2. The elements of the inputs and outputs are bitwise identical. + // 3. The sort is unstable since a later +0.0 will appear before an earlier + // -0.0. + // + // 0x80000000 + // 1. +0.0 == -0.0 + // 2. All -0.0 in the input are replaced with +0.0 in the output. + // 3. The sort is stable. 
+ xla::XlaOp max = b->ConstantR0<int32>(0x80000000); + xla::XlaOp index_mask = b->ConstantR0<int32>(0x0000FFFF); + xla::XlaOp value_mask = b->ConstantR0<int32>(0xFFFF0000); + + // Convert from bf16 to f32. The lower 16-bits are zero due to the + // definition of bf16. + xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32); + + // Negate the input to reverse sort it. The lower 16-bits are zero, because + // negating a float is just inverting the high-bit. + xla::XlaOp negative_input_f32 = b->Neg(input_f32); + + // Convert to a sign magnitude integer. The lower 16-bits are zero, since + // bitcast convert doesn't change any bits. + xla::XlaOp negative_input_sm32 = + b->BitcastConvertType(negative_input_f32, xla::S32); + + // Convert from sign magnitude integer to two's complement integer. The + // lower 16-bits are zero on both sides of the select. On the false side, + // the value is unchanged, and on the true side, the lower 16-bits of max + // are all zero, so the lower 16-bits of the result of the subtraction will + // also be zero. + xla::XlaOp negative_input_s32 = + b->Select(b->Lt(negative_input_sm32, zero), + b->Sub(max, negative_input_sm32), negative_input_sm32); + + // In order for the Or with iota_s32 to work properly, the lower 16-bits + // of negative_input_s32 must be zero. // Pack elements as: // * upper 16 bits are the value // * lower 16 bits are the index. - xla::XlaOp packed = b->BitcastConvertType( - b->Or(b->BitcastConvertType(b->ConvertElementType(input, xla::F32), - xla::S32), - iota), - xla::F32); + xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32); // TODO(phawkins): use a more efficient algorithm that does not require a // full sort. 
- xla::XlaOp sorted = b->Slice(b->Rev(b->Sort(packed), {0}), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); - - // Unpack the value/index - xla::XlaOp x = b->BitcastConvertType(sorted, xla::S32); - xla::XlaOp indices = b->And(x, b->ConstantR0<int32>(0x0000FFFF)); - xla::XlaOp values = b->ConvertElementType( - b->BitcastConvertType(b->And(x, b->ConstantR0<int32>(0xFFFF0000)), - xla::F32), - xla::BF16); - - context->SetOutput(0, values); - context->SetOutput(1, indices); + xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); + + // Unpack the value/index. + xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask); + xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask); + + // Convert from two's complement integer to sign magnitude integer. + xla::XlaOp negative_values_sm32 = + b->Select(b->Lt(negative_values_s32, zero), + b->Sub(max, negative_values_s32), negative_values_s32); + + xla::XlaOp negative_values_f32 = + b->BitcastConvertType(negative_values_sm32, xla::F32); + + // Negate the values to get back the original inputs. + xla::XlaOp values_f32 = b->Neg(negative_values_f32); + + // Convert from f32 to bf16. 
+ xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16); + + context->SetOutput(0, values_bf16); + context->SetOutput(1, indices_s32); } private: diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 722d882471..3a885b4389 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -461,5 +461,26 @@ XLA_TEST_F(ConvertTest, ConvertS64U64) { ComputeAndCompareR1<uint64>(&builder, unsigned_x, {}); } +XLA_TEST_F(ConvertTest, ConvertBF16F32) { + XlaBuilder builder(TestName()); + + std::vector<bfloat16> all_bfloats(1 << 16); + for (int i = 0; i < all_bfloats.size(); ++i) { + all_bfloats[i].value = i; + } + + std::vector<uint32> expected(all_bfloats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = (1U << 16) * i; + } + + // Exhaustively test all bf16 to f32 conversions. + xla::XlaOp all_bfloats_bf16 = builder.ConstantR1<bfloat16>(all_bfloats); + xla::XlaOp all_bfloats_f32 = + builder.ConvertElementType(all_bfloats_bf16, F32); + xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32); + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + } // namespace } // namespace xla |