author     Bjarke Hammersholt Roune <broune@google.com>  2017-12-08 13:37:33 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-08 13:41:13 -0800
commit     2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 (patch)
tree       016e5f89025746fed9d6643d9bfde209cc7ce4ee /tensorflow/compiler/xla/tests
parent     dc04e89bc6f0421bf77ac69f21c1f2f57618f53c (diff)
Add bfloat16 support to the CPU backend.
* A few ops, in particular Convert, directly support bfloat16.
* Added an HLO pass HloElementTypeConverter which converts graphs away from bfloat16 without changing the numerics, using Convert ops. This can be improved in many ways, but the feature here is that one can run XLA graphs that use bfloat16 on the CPU backend and get the correct result.

PiperOrigin-RevId: 178419829
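The pass leans on a simple numeric fact: every bfloat16 value is exactly representable in float32 (bfloat16 is just the upper 16 bits of an IEEE float32), so an op can be widened to float32, computed there, and rounded back. A minimal standalone sketch of that round-trip, with a hand-rolled Bf16 struct standing in for XLA's bfloat16 type and truncation standing in for the real rounding mode:

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    struct Bf16 { uint16_t bits; };  // upper 16 bits of a float32

    // Narrowing: keep the top 16 bits (real converters round to nearest).
    Bf16 FloatToBf16(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      return Bf16{static_cast<uint16_t>(u >> 16)};
    }

    // Widening is exact: zero-fill the dropped mantissa bits.
    float Bf16ToFloat(Bf16 b) {
      uint32_t u = static_cast<uint32_t>(b.bits) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    // The HloElementTypeConverter idea in miniature: wrap the op in
    // converts, compute in float32, convert the result back.
    Bf16 Bf16Add(Bf16 a, Bf16 b) {
      return FloatToBf16(Bf16ToFloat(a) + Bf16ToFloat(b));
    }

    int main() {
      Bf16 one = FloatToBf16(1.0f), half = FloatToBf16(0.5f);
      std::cout << Bf16ToFloat(Bf16Add(one, half)) << "\n";  // prints 1.5
    }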
Diffstat (limited to 'tensorflow/compiler/xla/tests')
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                  |  2 --
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc  |  2 +-
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc          | 16 ++++++++++++++++
3 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6f03f1a4e0..6af01ae80d 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -802,8 +802,6 @@ xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
blacklisted_backends = [
- "cpu",
- "cpu_parallel",
"gpu",
],
shard_count = 40,
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 330575a02e..b32df74312 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -53,7 +53,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
  public:
   ErrorSpec DefaultErrorSpec() const {
     if (use_bfloat16()) {
-      return ErrorSpec(1e-1, 3e-2);
+      return ErrorSpec(1e-1, 5e-2);
     } else {
       return ErrorSpec(1e-3, 1e-3);
     }
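The bfloat16 tolerance loosens here from 3e-2 to 5e-2. A plausible reading: bfloat16 stores only 7 explicit mantissa bits, so each value carries a relative rounding error of up to roughly 2^-8, and a reduce-window accumulates many such errors across the window. A rough standalone illustration (the 8x8 window size and harmonic values are arbitrary; truncation again stands in for the rounding mode):

    #include <cstdint>
    #include <cstring>
    #include <iostream>

    // Round-trip a float32 through bfloat16 by zeroing the low 16 bits.
    float RoundTripBf16(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      u &= 0xFFFF0000u;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    int main() {
      float exact = 0.0f, lossy = 0.0f;
      for (int i = 1; i <= 64; ++i) {  // e.g. summing an 8x8 window
        float v = 1.0f / i;            // values with long mantissas
        exact += v;
        lossy = RoundTripBf16(lossy + RoundTripBf16(v));
      }
      std::cout << "relative error: " << (exact - lossy) / exact << "\n";
    }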
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 93bce97a3e..780b292d1a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
       }));
 }
 
+// The standard library does not have a case for bfloat16, unsurprisingly, so we
+// handle that one specially.
+template <>
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(), BF16);
+  std::minstd_rand0 engine;
+  std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+  TF_CHECK_OK(literal->Populate<bfloat16>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return static_cast<bfloat16>(generator(engine));
+      }));
+}
+
 template <typename IntT>
 void PopulateWithRandomIntegralData(Literal* literal) {
   CHECK_EQ(literal->shape().element_type(),
@@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
   }
   std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
   switch (shape.element_type()) {
+    case BF16:
+      PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
+      break;
     case F32:
       PopulateWithRandomFloatingPointData<float>(literal.get());
       break;
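Two notes on the test_utils.cc hunks above. The explicit specialization is needed because std::uniform_real_distribution is only specified for float, double, and long double, so bfloat16 samples are drawn in float and narrowed by the lambda's static_cast. And with the BF16 case added to MakeFakeLiteral's switch, fabricating bfloat16 test inputs should look the same as for any float type; a hedged usage sketch (the ShapeUtil::MakeShape call and {2, 3} dimensions are illustrative, assuming this file's namespace and headers):

    // Sketch only, not part of this commit.
    Shape shape = ShapeUtil::MakeShape(BF16, {2, 3});
    StatusOr<std::unique_ptr<Literal>> fake = MakeFakeLiteral(shape);
    TF_CHECK_OK(fake.status());  // elements drawn uniformly from [0, 1)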