diff options
author | 2017-12-08 13:37:33 -0800 | |
---|---|---|
committer | 2017-12-08 13:41:13 -0800 | |
commit | 2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 (patch) | |
tree | 016e5f89025746fed9d6643d9bfde209cc7ce4ee /tensorflow/compiler/xla/tests | |
parent | dc04e89bc6f0421bf77ac69f21c1f2f57618f53c (diff) |
Add bfloat16 support to the CPU backend.
* A few ops, in particular Convert, directly support bfloat16.
* Added an HLO pass HloElementTypeConverter which converts graphs away from bfloat16
without changing the numerics, using Convert ops.
This can be improved in many ways, but the feature here is that one can run XLA graphs that use bfloat16 on the CPU backend and get the correct result.
PiperOrigin-RevId: 178419829
Diffstat (limited to 'tensorflow/compiler/xla/tests')
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_window_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 16 |
3 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 6f03f1a4e0..6af01ae80d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -802,8 +802,6 @@ xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], blacklisted_backends = [ - "cpu", - "cpu_parallel", "gpu", ], shard_count = 40, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 330575a02e..b32df74312 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -53,7 +53,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { if (use_bfloat16()) { - return ErrorSpec(1e-1, 3e-2); + return ErrorSpec(1e-1, 5e-2); } else { return ErrorSpec(1e-3, 1e-3); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 93bce97a3e..780b292d1a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { })); } +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution<float> generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate<bfloat16>( + [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) { + return static_cast<bfloat16>(generator(engine)); + })); +} + template <typename IntT> void PopulateWithRandomIntegralData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), @@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) { } std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData<bfloat16>(literal.get()); + break; case F32: PopulateWithRandomFloatingPointData<float>(literal.get()); break; |