author    Peter Hawkins <phawkins@google.com>  2017-11-30 13:50:13 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-30 13:55:02 -0800
commit    39cac0519176d1244b0e29d6c28691189ea755ec (patch)
tree      1e246a98ad6479d87ade6c4b87c47518d5f052f1
parent    15b06e060af59a1e30f4a9079679718aaa68dbc7 (diff)
[TF:XLA] Allow bfloat16 types in more places.
PiperOrigin-RevId: 177502497
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matmul_op.cc  |  4
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc           |  3
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc        |  7
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc          |  6
-rw-r--r--  tensorflow/core/framework/bfloat16_test.cc       | 12
-rw-r--r--  tensorflow/core/framework/numeric_types.h        | 43
6 files changed, 65 insertions(+), 10 deletions(-)
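
For context on the type this change threads through: bfloat16 stores the upper 16 bits of an IEEE-754 float32 (sign bit, full 8-bit exponent, top 7 mantissa bits), so converting is a truncation and widening is a zero-fill. Below is a minimal standalone sketch of that round-trip, independent of the TensorFlow headers; FloatToBFloat16Bits and BFloat16BitsToFloat are hypothetical names, not the repository's API.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float32 to bfloat16 storage: keep only the high 16 bits.
static uint16_t FloatToBFloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // bit-cast without aliasing UB
  return static_cast<uint16_t>(bits >> 16);
}

// Widen bfloat16 storage back to float32: zero-fill the dropped mantissa bits.
static float BFloat16BitsToFloat(uint16_t v) {
  uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float x = 3.14159f;
  uint16_t b = FloatToBFloat16Bits(x);
  std::printf("%g -> 0x%04x -> %g\n", x, static_cast<unsigned>(b),
              BFloat16BitsToFloat(b));
  // Prints: 3.14159 -> 0x4049 -> 3.140625
}

Truncation preserves float32's full exponent range but leaves only about two to three decimal digits of precision, which is the trade-off the kernels below opt into.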
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index fcef497e58..a62d233526 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -23,8 +23,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr std::array<DataType, 4> kMatmulTypes = {
- {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
+constexpr std::array<DataType, 5> kMatmulTypes = {
+ {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
class MatMulOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 7ffe0aa6df..943248aedb 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
case xla::F16:
return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
break;
+ case xla::BF16:
+ return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value));
+ break;
case xla::F32:
return builder->ConstantR0<float>(static_cast<float>(value));
break;
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 9c3e15d2fa..ec9e535b70 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines helper routines for Tla JIT compilation.
+// This file defines helper routines for XLA compilation.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
@@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
DataType data_type) {
switch (data_type) {
+ case DT_BFLOAT16:
+ return b->ConstantR0<bfloat16>(bfloat16::epsilon());
case DT_FLOAT:
return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
case DT_DOUBLE:
@@ -169,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
case xla::S16:
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case xla::BF16:
+ literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value));
+ break;
case xla::F16:
literal =
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
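
The new DT_BFLOAT16 case in XlaHelpers::Epsilon relies on bfloat16::epsilon(), which the numeric_types.h hunk further down defines as the bit pattern 0x3c00. As a sanity check (a standalone decode, not TF code), that pattern does come out to 2^-7, the machine epsilon for a format with 7 explicit mantissa bits:

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const uint16_t kEpsilonBits = 0x3c00;              // bfloat16::epsilon().value
  const int sign     = (kEpsilonBits >> 15) & 0x1;   // 0
  const int exponent = (kEpsilonBits >> 7) & 0xff;   // 0x78 = 120
  const int mantissa = kEpsilonBits & 0x7f;          // 0
  assert(sign == 0 && mantissa == 0);
  // Implicit leading 1, bias 127: value = 2^(120 - 127) = 2^-7.
  const double value = std::ldexp(1.0, exponent - 127);
  assert(value == 0.0078125);  // matches the "0x1.0p-7" comment in the patch
}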
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 93d3cd425f..250df5f4d5 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -252,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int32>(1);
case S64:
return *Literal::CreateR0<int64>(1);
+ case F16:
+ return *Literal::CreateR0<half>(static_cast<half>(1.0f));
+ case BF16:
+ return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
return *Literal::CreateR0<float>(1);
case F64:
@@ -263,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal,
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
- case F16:
- return *Literal::CreateR0<half>(static_cast<half>(1.0f));
case TUPLE:
LOG(FATAL) << "tuple element type cannot take on value of 1";
case OPAQUE:
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 6e45338751..17e6209f8e 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -104,6 +105,17 @@ TEST(Bfloat16Test, Conversion) {
}
}
+TEST(Bfloat16Test, Epsilon) {
+ EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
+ EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
+ bfloat16(1.0f)));
+}
+
+TEST(Bfloat16Test, Negate) {
+ EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f)));
+ EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f)));
+}
+
static void BM_FloatToBFloat16(int iters) {
testing::StopTiming();
static const int N = 32 << 20;
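
The Epsilon test above works because the new bfloat16 operators (next file) compute in float32 and truncate the result back to 16 bits. A standalone illustration of that rounding behavior, under the same truncation assumption (ToBF16/FromBF16 are local helpers, not TF's API): adding 2^-7 to 1.0 sets a mantissa bit inside the 7 bits bfloat16 keeps, while adding 2^-8 sets a bit in the 16 low bits that truncation discards.

#include <cassert>
#include <cstdint>
#include <cstring>

static uint16_t ToBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncate, as the framework does
}

static float FromBF16(uint16_t v) {
  uint32_t u = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  const float eps = 0.0078125f;  // 2^-7, i.e. bfloat16::epsilon()
  // The bit set by adding 2^-7 survives truncation...
  assert(FromBF16(ToBF16(1.0f + eps)) > 1.0f);
  // ...but the bit set by adding 2^-8 is lost, so the sum rounds back to 1.0.
  assert(FromBF16(ToBF16(1.0f + eps / 2.0f)) == 1.0f);
}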
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 2b080e13fd..29cac26244 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -121,15 +121,48 @@ struct bfloat16 {
return static_cast<double>(float(*this));
}
+ static bfloat16 epsilon() {
+ bfloat16 x;
+ x.value = 0x3c00; // 0x1.0p-7
+ return x;
+ }
+
uint16_t value;
};
-inline bool operator==(const bfloat16 a, const bfloat16 b) {
- return a.value == b.value;
+inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
-
-inline bool operator!=(const bfloat16 a, const bfloat16 b) {
- return a.value != b.value;
+inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) - static_cast<float>(b));
+}
+inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) * static_cast<float>(b));
+}
+inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+inline bfloat16 operator-(bfloat16 a) {
+ a.value ^= 0x8000;
+ return a;
+}
+inline bool operator<(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) < static_cast<float>(b);
+}
+inline bool operator<=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) <= static_cast<float>(b);
+}
+inline bool operator==(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) == static_cast<float>(b);
+}
+inline bool operator!=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) != static_cast<float>(b);
+}
+inline bool operator>(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) > static_cast<float>(b);
+}
+inline bool operator>=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) >= static_cast<float>(b);
}
} // end namespace tensorflow
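
Two operator styles coexist in the hunk above: binary arithmetic round-trips through float32, while unary minus flips only the sign bit, which is exact for every input (including +/-0, infinities, and NaN) and needs no conversion. A self-contained sketch of both, where BF16 is a stand-in for the struct above rather than the TensorFlow type:

#include <cassert>
#include <cstdint>
#include <cstring>

struct BF16 {
  uint16_t value;
  explicit BF16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    value = static_cast<uint16_t>(u >> 16);  // truncating conversion
  }
  explicit operator float() const {
    uint32_t u = static_cast<uint32_t>(value) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

// Binary arithmetic: widen to float32, compute, truncate back.
inline BF16 operator+(BF16 a, BF16 b) {
  return BF16(static_cast<float>(a) + static_cast<float>(b));
}

// Unary minus: flip the sign bit in place, no float round-trip.
inline BF16 operator-(BF16 a) {
  a.value ^= 0x8000;
  return a;
}

int main() {
  assert(static_cast<float>(-BF16(3.0f)) == -3.0f);  // mirrors the Negate test
  assert(static_cast<float>(BF16(1.5f) + BF16(2.0f)) == 3.5f);
}

Comparing through float32 (rather than comparing raw bits, as the old operator== did) also gives the IEEE-consistent answers for signed zero and NaN, e.g. bfloat16(-0.0f) == bfloat16(0.0f).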