diff options
author | Bjarke Hammersholt Roune <broune@google.com> | 2017-12-08 13:37:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-08 13:41:13 -0800 |
commit | 2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 (patch) | |
tree | 016e5f89025746fed9d6643d9bfde209cc7ce4ee /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | dc04e89bc6f0421bf77ac69f21c1f2f57618f53c (diff) |
Add bfloat16 support to the CPU backend.
* A few ops, in particular Convert, directly support bfloat16.
* Added an HLO pass HloElementTypeConverter which converts graphs away from bfloat16
without changing the numerics, using Convert ops.
This can be improved in many ways, but the feature here is that one can run XLA graphs that use bfloat16 on the CPU backend and get the correct result.
PiperOrigin-RevId: 178419829
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 93bce97a3e..780b292d1a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { })); } +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution<float> generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate<bfloat16>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return static_cast<bfloat16>(generator(engine)); + })); +} + template <typename IntT> void PopulateWithRandomIntegralData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), @@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { } std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData<bfloat16>(literal.get()); + break; case F32: PopulateWithRandomFloatingPointData<float>(literal.get()); break; |