From 39cac0519176d1244b0e29d6c28691189ea755ec Mon Sep 17 00:00:00 2001 From: Peter Hawkins <phawkins@google.com> Date: Thu, 30 Nov 2017 13:50:13 -0800 Subject: [TF:XLA] Allow bfloat16 types in more places. PiperOrigin-RevId: 177502497 --- tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 4 +-- tensorflow/compiler/tf2xla/lib/util.cc | 3 ++ tensorflow/compiler/tf2xla/xla_helpers.cc | 7 +++- tensorflow/compiler/xla/literal_util.cc | 6 ++-- tensorflow/core/framework/bfloat16_test.cc | 12 +++++++ tensorflow/core/framework/numeric_types.h | 43 ++++++++++++++++++++++--- 6 files changed, 65 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e58..a62d233526 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,8 +23,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array<DataType, 4> kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array<DataType, 5> kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 7ffe0aa6df..943248aedb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, case xla::F16: return builder->ConstantR0<xla::half>(static_cast<xla::half>(value)); break; + case xla::BF16: + return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value)); + break; case xla::F32: return builder->ConstantR0<float>(static_cast<float>(value)); break; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9c3e15d2fa..ec9e535b70 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,7 +13,7 @@ See the License for the specific language 
governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" @@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0<bfloat16>(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0<float>(std::numeric_limits<float>::epsilon()); case DT_DOUBLE: @@ -169,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)); + break; case xla::F16: literal = *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 93d3cd425f..250df5f4d5 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -252,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0<int32>(1); case S64: return *Literal::CreateR0<int64>(1); + case F16: + return *Literal::CreateR0<half>(static_cast<half>(1.0f)); + case BF16: + return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)); case F32: return *Literal::CreateR0<float>(1); case F64: @@ -263,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal, case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0<half>(static_cast<half>(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index 6e45338751..17e6209f8e 100644 --- 
a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -104,6 +105,17 @@ TEST(Bfloat16Test, Conversion) { } } +TEST(Bfloat16Test, Epsilon) { + EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f))); + EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) + + bfloat16(1.0f))); +} + +TEST(Bfloat16Test, Negate) { + EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f))); + EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f))); +} + static void BM_FloatToBFloat16(int iters) { testing::StopTiming(); static const int N = 32 << 20; diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 2b080e13fd..29cac26244 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -121,15 +121,48 @@ struct bfloat16 { return static_cast<double>(float(*this)); } + static bfloat16 epsilon() { + bfloat16 x; + x.value = 0x3c00; // 0x1.0p-7 + return x; + } + uint16_t value; }; -inline bool operator==(const bfloat16 a, const bfloat16 b) { - return a.value == b.value; +inline bfloat16 operator+(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) + static_cast<float>(b)); } - -inline bool operator!=(const bfloat16 a, const bfloat16 b) { - return a.value != b.value; +inline bfloat16 operator-(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) - static_cast<float>(b)); +} +inline bfloat16 operator*(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) * static_cast<float>(b)); +} +inline bfloat16 operator/(bfloat16 a, bfloat16 b) { + return bfloat16(static_cast<float>(a) / static_cast<float>(b)); +} +inline bfloat16 operator-(bfloat16 a) { + a.value ^= 0x8000; + return a; +} +inline bool 
operator<(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) < static_cast<float>(b); +} +inline bool operator<=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) <= static_cast<float>(b); +} +inline bool operator==(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) == static_cast<float>(b); +} +inline bool operator!=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) != static_cast<float>(b); +} +inline bool operator>(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) > static_cast<float>(b); +} +inline bool operator>=(bfloat16 a, bfloat16 b) { + return static_cast<float>(a) >= static_cast<float>(b); } } // end namespace tensorflow -- cgit v1.2.3