aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2017-12-08 13:37:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-08 13:41:13 -0800
commit2f16f3afdcde16cf0de2f051c57b32cd61a12ec0 (patch)
tree016e5f89025746fed9d6643d9bfde209cc7ce4ee /tensorflow/compiler/xla/tests/test_utils.cc
parentdc04e89bc6f0421bf77ac69f21c1f2f57618f53c (diff)
Add bfloat16 support to the CPU backend.
* A few ops, in particular Convert, directly support bfloat16. * Added an HLO pass HloElementTypeConverter which converts graphs away from bfloat16 without changing the numerics, using Convert ops. This can be improved in many ways, but the feature here is that one can run XLA graphs that use bfloat16 on the CPU backend and get the correct result. PiperOrigin-RevId: 178419829
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 93bce97a3e..780b292d1a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -35,6 +35,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
}));
}
+// The standard library does not have a case for bfloat16, unsurprisingly, so we
+// handle that one specially.
+template <>
+void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
+ CHECK_EQ(literal->shape().element_type(), BF16);
+ std::minstd_rand0 engine;
+ std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ TF_CHECK_OK(literal->Populate<bfloat16>(
+ [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ return static_cast<bfloat16>(generator(engine));
+ }));
+}
+
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal) {
CHECK_EQ(literal->shape().element_type(),
@@ -171,6 +184,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
}
std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
switch (shape.element_type()) {
+ case BF16:
+ PopulateWithRandomFloatingPointData<bfloat16>(literal.get());
+ break;
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get());
break;