From 2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 Mon Sep 17 00:00:00 2001
From: Bjarke Hammersholt Roune
Date: Fri, 8 Dec 2017 13:37:33 -0800
Subject: Add bfloat16 support to the CPU backend.

* A few ops, in particular Convert, directly support bfloat16.

* Added an HLO pass HloElementTypeConverter which converts graphs away from
  bfloat16 without changing the numerics, using Convert ops. This can be
  improved in many ways, but the feature here is that one can run XLA graphs
  that use bfloat16 on the CPU backend and get the correct result.

PiperOrigin-RevId: 178419829
---
 tensorflow/compiler/xla/tests/test_utils.cc | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 93bce97a3e..780b292d1a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
       }));
 }
 
+// The standard library does not have a case for bfloat16, unsurprisingly, so we
+// handle that one specially.
+template <>
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(), BF16);
+  std::minstd_rand0 engine;
+  std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+  TF_CHECK_OK(literal->Populate<bfloat16>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return static_cast<bfloat16>(generator(engine));
+      }));
+}
+
 template <typename IntT>
 void PopulateWithRandomIntegralData(Literal* literal) {
   CHECK_EQ(literal->shape().element_type(),
@@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
   }
   std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
   switch (shape.element_type()) {
+    case BF16:
+      PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
+      break;
     case F32:
      PopulateWithRandomFloatingPointData<float>(literal.get());
       break;
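
For context, a minimal sketch of how the new BF16 path can be exercised from a
test, assuming the XLA test utilities at this revision; the helper name
MakeFakeBF16Vector and the exact set of includes are illustrative, not part of
the patch:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"

namespace xla {

// Illustrative helper (hypothetical): builds a rank-1 BF16 shape and asks
// MakeFakeLiteral for random data. With this change, MakeFakeLiteral
// dispatches to the bfloat16 specialization of
// PopulateWithRandomFloatingPointData instead of failing on an unhandled
// element type.
StatusOr<std::unique_ptr<Literal>> MakeFakeBF16Vector() {
  Shape shape = ShapeUtil::MakeShape(BF16, {16});
  return MakeFakeLiteral(shape);
}

}  // namespace xla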