aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-11-11 07:46:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-11 07:50:10 -0800
commit1a70297c95643ab047f4ef069851523ea5d4d5b3 (patch)
treeea0e80c24ec3ea498804840730b26e859936b290
parent07ac134a4b2fa6f40f7fa8f7266b854a529ddcab (diff)
Automated g4 rollback of changelist 175252067
PiperOrigin-RevId: 175401676
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc3
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/literal_util.cc99
-rw-r--r--tensorflow/compiler/xla/literal_util.h23
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc62
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc8
-rw-r--r--tensorflow/compiler/xla/primitive_util.h7
-rw-r--r--tensorflow/compiler/xla/service/backend.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc3
-rw-r--r--tensorflow/compiler/xla/shape_util.cc1
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc13
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc3
-rw-r--r--tensorflow/compiler/xla/types.h3
-rw-r--r--tensorflow/compiler/xla/xla_data.proto13
-rw-r--r--tensorflow/core/framework/bfloat16.cc30
-rw-r--r--tensorflow/core/framework/bfloat16_test.cc92
-rw-r--r--tensorflow/core/framework/numeric_types.h251
19 files changed, 44 insertions, 580 deletions
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index c969212a1b..1efbe0ffb1 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -49,9 +49,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_UINT64:
*type = xla::U64;
return Status::OK();
- case tensorflow::DT_BFLOAT16:
- *type = xla::BF16;
- return Status::OK();
case tensorflow::DT_HALF:
*type = xla::F16;
return Status::OK();
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 515b572b0e..fa4d348ebd 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -77,7 +77,6 @@ cc_library(
hdrs = ["types.h"],
visibility = [":friends"],
deps = [
- "//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 93d3cd425f..0cb2223ae5 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,20 +33,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-namespace {
-using tensorflow::int64;
-
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
-// Converts between little and big endian, assuming elements in the array are 16
-// bits long.
-void ConvertEndianShort(char* bytes, int64 size) {
- CHECK_EQ(size / 2, 0);
- for (int64 i = 0; i < size; i += 2) {
- std::swap(bytes[i], bytes[i + 1]);
- }
-}
-} // namespace
namespace xla {
@@ -183,8 +169,6 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
case F16:
return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
- case BF16:
- return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
case F32:
return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64:
@@ -216,8 +200,6 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int64>(0);
case F16:
return *Literal::CreateR0<half>(static_cast<half>(0.0f));
- case BF16:
- return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
return *Literal::CreateR0<float>(0);
case F64:
@@ -303,9 +285,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
- case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -342,9 +321,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));
- case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -452,7 +428,6 @@ std::unique_ptr<Literal> Literal::Transpose(
// The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
// most minor.
- //
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
// dimension has within the transposed array, a layout is affine if
@@ -561,9 +536,6 @@ string Literal::GetAsString(
}
case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index));
- case BF16:
- return tensorflow::strings::StrCat(
- static_cast<float>(Get<bfloat16>(multi_index)));
default:
return tensorflow::strings::StrCat(
"[", PrimitiveType_Name(shape().element_type()), "]");
@@ -771,8 +743,6 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(c64s_.data());
case F16:
return reinterpret_cast<void*>(f16s_.data());
- case BF16:
- return reinterpret_cast<void*>(bf16s_.data());
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -815,9 +785,6 @@ void Literal::Reserve(int64 num_elements) {
case F16:
Resize<half>(num_elements, static_cast<half>(0.0f));
break;
- case BF16:
- Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
- break;
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -857,9 +824,6 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F16:
actual = f16s().size() / sizeof(half);
break;
- case BF16:
- actual = bf16s().size();
- break;
default:
return tensorflow::errors::Unimplemented(
"unhandled element type for literal validation: " +
@@ -956,7 +920,6 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
- CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
return ConvertToC64<primitive_src_type>(src_literal);
@@ -986,9 +949,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
- CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
+ // Other types are not yet supported.
default:
return InvalidArgument("Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(shape().element_type()).c_str(),
@@ -1057,8 +1019,6 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index);
case F16:
return EqualElements<half>(*this, other, 0, &multi_index);
- case BF16:
- return EqualElements<bfloat16>(*this, other, 0, &multi_index);
case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index);
default:
@@ -1168,19 +1128,14 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
+ // TODO - there is an endianess problem here. fix it, or wait for uint16
+ // support in protobuf
auto values = mutable_f16s();
return tensorflow::gtl::MutableArraySlice<half>(values->data(),
values->size());
}
template <>
-tensorflow::gtl::MutableArraySlice<bfloat16>
-Literal::GetMutableArraySlice<bfloat16>() {
- auto values = mutable_bf16s();
- return {values->data(), values->size()};
-}
-
-template <>
tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
CHECK_EQ(shape().element_type(), PRED);
return tensorflow::gtl::ArraySlice<bool>(
@@ -1251,12 +1206,6 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
}
template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
- CHECK_EQ(shape().element_type(), BF16);
- return {bf16s().data(), bf16s().size()};
-}
-
-template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const {
CHECK_EQ(shape().element_type(), C64);
@@ -1304,9 +1253,6 @@ bool Literal::IsAll(int8 value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
case PRED:
if (value == 0) {
return AllElementsEqualValue<bool>(*this, false);
@@ -1328,9 +1274,6 @@ bool Literal::IsAllFloat(float value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
default:
return false;
}
@@ -1367,8 +1310,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
- case BF16:
- return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED:
return Get<bool>(indices) == false;
default:
@@ -1437,12 +1378,6 @@ void Literal::Resize<half>(int64 num_elements, half value) {
}
template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_bf16s()->resize(num_elements, value);
-}
-
-template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_c64s()->resize(num_elements, value);
@@ -1490,19 +1425,6 @@ LiteralProto Literal::ToProto() const {
*proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()),
f16s_.size() * sizeof(half));
- if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
- proto.f16s().size());
- }
- break;
- case BF16:
- *proto.mutable_bf16s() =
- string(reinterpret_cast<const char*>(bf16s_.data()),
- bf16s_.size() * sizeof(bfloat16));
- if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
- proto.bf16s().size());
- }
break;
case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1571,21 +1493,6 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half));
memcpy(f16s_.data(), s.data(), s.size());
-
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
- }
- break;
- }
- case BF16: {
- const string& s(literal_proto.bf16s());
- CHECK_EQ(0, s.size() % sizeof(bfloat16));
- bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
- memcpy(bf16s_.data(), s.data(), s.size());
-
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
- }
break;
}
case F32:
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index f37e529caf..667f926c46 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -163,11 +163,6 @@ class Literal {
const std::vector<complex64>& c64s() const { return c64s_; }
std::vector<complex64>* mutable_c64s() { return &c64s_; }
- int bf16s_size() const { return bf16s().size(); }
- bfloat16 bf16s(int i) const { return bf16s_[i]; }
- const std::vector<bfloat16>& bf16s() const { return bf16s_; }
- std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
-
int tuple_literals_size() const { return tuple_literals().size(); }
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
Literal* add_tuple_literals() {
@@ -627,7 +622,6 @@ class Literal {
std::vector<uint16> u16s_;
std::vector<uint32> u32s_;
std::vector<uint64> u64s_;
- std::vector<bfloat16> bf16s_;
std::vector<half> f16s_;
std::vector<float> f32s_;
std::vector<double> f64s_;
@@ -681,9 +675,6 @@ template <>
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
-
-template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const;
@@ -724,9 +715,6 @@ template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
template <>
-tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
-
-template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
template <>
@@ -760,9 +748,6 @@ template <>
void Literal::Resize<half>(int64 num_elements, half value);
template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
-
-template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value);
template <typename NativeT>
@@ -1005,14 +990,6 @@ inline half Literal::Get<half>(
return GetArraySlice<half>()[linear_index];
}
-template <>
-inline bfloat16 Literal::Get<bfloat16>(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(shape().element_type() == BF16);
- int64 linear_index = LinearIndex(multi_index);
- return GetArraySlice<bfloat16>()[linear_index];
-}
-
template <typename NativeT>
void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value) {
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 1e08101759..6d596da4ad 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -110,18 +110,6 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
-
- auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- ASSERT_EQ("0.5", bf16_lit->ToString());
-
- // 3.14 will be rounded to 3.125 in bfloat16 format (Round to nearest even).
- auto bf16_lit_truncated =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
-
- auto bf16_lit_truncated2 =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- ASSERT_EQ("9", bf16_lit_truncated2->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
@@ -409,18 +397,6 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
- bfloat16 b8(8.0f);
- bfloat16 b9(9.0f);
-
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
-
- // 9.001 will be truncated to 9.0
- bfloat16 b91(9.001f);
- bfloat16 b90(9.00f);
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
-
complex64 c8_9 = {8, 9};
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
@@ -715,30 +691,6 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
EXPECT_EQ(output, *expected);
}
-TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
- Literal output;
- bfloat16 h(0.25f);
- output.PopulateWithValue<bfloat16>(h, {});
- auto expected = Literal::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
-}
-
-TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
- Literal output;
- bfloat16 h(0.5f);
- output.PopulateWithValue<bfloat16>(h, {3});
- auto expected = Literal::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
-}
-
-TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
- Literal output;
- bfloat16 h(2.0f);
- output.PopulateWithValue<bfloat16>(h, {2, 2});
- auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
-}
-
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
@@ -1023,14 +975,6 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_);
- auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
- {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
- {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
- {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
- {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
- {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
- {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
- }}, layout_r4_dim0major_);
auto f32 = Literal::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@@ -1064,12 +1008,6 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s8->Convert(PRED).ConsumeValueOrDie();
EXPECT_EQ(*conv, *pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
-
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
-
conv = pred->Convert(S32).ConsumeValueOrDie();
EXPECT_EQ(*conv, *int32_pred);
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 2bce56b7bd..2113b5e06f 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -79,11 +79,6 @@ PrimitiveType NativeToPrimitiveType<double>() {
}
template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>() {
- return BF16;
-}
-
-template <>
PrimitiveType NativeToPrimitiveType<half>() {
return F16;
}
@@ -94,7 +89,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
}
bool IsFloatingPointType(PrimitiveType type) {
- return type == F16 || type == F32 || type == F64 || type == BF16;
+ return type == F16 || type == F32 || type == F64;
}
bool IsComplexType(PrimitiveType type) { return type == C64; }
@@ -123,7 +118,6 @@ int BitWidth(PrimitiveType type) {
case S16:
case U16:
case F16:
- case BF16:
return 16;
case U32:
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index 19c6a13888..a49c8b86fc 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -77,8 +77,6 @@ template <>
PrimitiveType NativeToPrimitiveType<double>();
template <>
PrimitiveType NativeToPrimitiveType<half>();
-template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>();
// Complex
template <>
@@ -169,11 +167,6 @@ struct PrimitiveTypeToNative<F16> {
using type = half;
};
-template <>
-struct PrimitiveTypeToNative<BF16> {
- using type = bfloat16;
-};
-
// Complex
template <>
struct PrimitiveTypeToNative<C64> {
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 05f2d06278..9abe30e3f3 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
-
#include "tensorflow/compiler/xla/service/backend.h"
#include <algorithm>
#include <string>
#include <utility>
+#define EIGEN_USE_THREADS
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index f385829cdf..f8e260dd90 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
+
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include <memory>
#include <string>
#include <tuple>
+#define EIGEN_USE_THREADS
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index a722d1b3d9..88b77ccdd0 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1450,10 +1450,6 @@ HloEvaluator::HloEvaluator() {
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
-
- typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
- });
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
});
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 158fb9a546..f463e57d99 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/service/hlo_runner.h"
@@ -20,6 +19,8 @@ limitations under the License.
#include <string>
#include <utility>
+#define EIGEN_USE_THREADS
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 4d0bafa908..b5eb81dfc6 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -263,7 +263,6 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
case S32:
case S64:
case F16:
- case BF16:
case F32:
case F64:
return true;
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 75c9a0d3fb..95a52ecd2f 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -116,18 +116,16 @@ template <typename FloatT, typename UnsignedT>
::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
- auto lhs_double = static_cast<double>(lhs);
- auto rhs_double = static_cast<double>(rhs);
if (ulhs != urhs) {
return ::testing::AssertionFailure() << tensorflow::strings::Printf(
"floating values are not bitwise-equal; and equality testing "
"was requested: %s=%g=%a vs %s=%g=%a",
tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
.c_str(),
- lhs_double, lhs_double,
+ lhs, lhs,
tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
.c_str(),
- rhs_double, rhs_double);
+ rhs, rhs);
}
return ::testing::AssertionSuccess();
}
@@ -151,10 +149,6 @@ template <typename NativeT>
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
-::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
- return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
-}
-template <>
::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
}
@@ -244,9 +238,6 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case U64:
match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
break;
- case BF16:
- match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
- break;
case F32:
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
break;
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index d98875dbc2..c11e1df0a7 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#define EIGEN_USE_THREADS
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include <vector>
+#define EIGEN_USE_THREADS
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/map_util.h"
diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h
index 9fa4297523..3b19ca321c 100644
--- a/tensorflow/compiler/xla/types.h
+++ b/tensorflow/compiler/xla/types.h
@@ -19,7 +19,6 @@ limitations under the License.
#include <complex>
#include "third_party/eigen3/Eigen/Core"
-#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
#include <Eigen/Core>
@@ -33,8 +32,6 @@ using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int64;
-using ::tensorflow::bfloat16;
-
using ::tensorflow::uint8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index eac8f2ff07..7146604708 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -46,12 +46,6 @@ enum PrimitiveType {
// converted to f16 from f32 at arbirary points in the computation.
F16 = 10;
F32 = 11;
-
- // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
- // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
- // and 7 bits for the mantissa.
- BF16 = 16;
-
F64 = 12;
// Complex values of fixed width.
@@ -69,8 +63,6 @@ enum PrimitiveType {
// An opaque type used for passing context specific data to a custom
// operation.
OPAQUE = 14;
-
- // Next = 17
}
// Describes the value held inside padding elements.
@@ -318,10 +310,7 @@ message LiteralProto {
repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated LiteralProto tuple_literals = 10;
- // The F16s and BF16s are encoded in little endian byte order
- bytes f16s = 11;
- bytes bf16s = 13;
- // Next = 14
+ bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
}
message WindowDimension {
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
index 1a6f355c77..a5ac0e1a8d 100644
--- a/tensorflow/core/framework/bfloat16.cc
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -18,24 +18,32 @@ limitations under the License.
namespace tensorflow {
void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
- for (int64 i = 0; i < size; ++i) {
- dst[i] = bfloat16(src[i]);
- }
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+ uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ for (; size != 0; p += 2, q++, size--) {
+ *q = p[0];
+ }
+#else
+ for (; size != 0; p += 2, q++, size--) {
+ *q = p[1];
+ }
+#endif
}
void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p++, q += 2, size--) {
- q[0] = *p;
- q[1] = 0;
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = *p;
+ q[1] = 0;
}
-#else
- for (; size != 0; p++, q += 2, size--) {
- q[0] = 0;
- q[1] = *p;
- }
+#else
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = 0;
+ q[1] = *p;
+ }
#endif
}
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index a25b764ea2..af4e6a4411 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
-#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -28,97 +27,6 @@ TEST(Bfloat16Test, Simple) {
EXPECT_EQ(0x4140, a.value);
}
-float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
- uint32_t low_mantissa) {
- return bit_cast<float>((sign << 31) + (exponent << 23) +
- (high_mantissa << 16) + low_mantissa);
-}
-
-struct Bfloat16TestParam {
- float input;
- float expected;
-};
-
-class Bfloat16Test : public ::testing::Test,
- public ::testing::WithParamInterface<Bfloat16TestParam> {};
-
-TEST_P(Bfloat16Test, RoundOrTruncate) {
- bfloat16 a(GetParam().input);
- if (std::isnan(GetParam().input)) {
- EXPECT_TRUE(std::isnan(float(a)));
- return;
- }
- EXPECT_EQ(GetParam().expected, float(a));
-}
-
-INSTANTIATE_TEST_CASE_P(
- Bfloat16Test_Instantiation, Bfloat16Test,
- ::testing::Values(
- // More than half.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
- BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
-
- Bfloat16TestParam{
- BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
- BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
-
- // Exact half.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
-
- // NaN stays at NaN.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
- BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
-
- // NaN stays at NaN -- no exponents overflow.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
- BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
-
- // More than half, round to an odd number.
- Bfloat16TestParam{
- BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
- BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
-
- // Less than half, truncate.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
-
- // Less than half, truncate.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
-
- // Exact at half, but result is already even.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
- BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
-
- // Denormal values.
- Bfloat16TestParam{
- BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
- BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
- Bfloat16TestParam{
- BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
- BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
-TEST(Bfloat16Test, RoundWithFractionOverflow) {
- // Still works with fraction overflow -- round to 4./
- //
- // Input 3.9960938:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1100000000000000
- //
- // Should round to 4.0:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0
- bfloat16 a(3.9960938f);
- EXPECT_EQ(4.0, float(a));
-}
-
TEST(Bfloat16Test, Conversion) {
float a[100];
for (int i = 0; i < 100; ++i) {
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index d005de2af1..a630bee38d 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,262 +44,29 @@ typedef Eigen::QUInt16 quint16;
// see framework/bfloat16.h for description.
struct bfloat16 {
EIGEN_DEVICE_FUNC bfloat16() {}
-
- explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
- uint32_t input;
- memcpy(&input, &v, sizeof(uint32_t));
-
- if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
- // If the value is a NaN, squash it to a qNaN with msb of fraction set,
- // this makes sure after truncation we don't end up with an inf.
- //
- // qNaN magic: All exponent bits set + most significant bit of fraction
- // set.
- value = 0x7fc0;
- } else {
- // Fast rounding algorithm that rounds a half value to nearest even. This
- // reduces expected error when we convert a large number of floats. Here
- // is how it works:
- //
- // Definitions:
- // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
- // with the following tags:
- //
- // Sign | Exp (8 bits) | Frac (23 bits)
- // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
- //
- // S: Sign bit.
- // E: Exponent bits.
- // F: First 6 bits of fraction.
- // L: Least significant bit of resulting bfloat16 if we truncate away the
- // rest of the float32. This is also the 7th bit of fraction
- // R: Rounding bit, 8th bit of fraction.
- // T: Sticky bits, rest of fraction, 15 bits.
- //
- // To round half to nearest even, there are 3 cases where we want to round
- // down (simply truncate the result of the bits away, which consists of
- // rounding bit and sticky bits) and two cases where we want to round up
- // (truncate then add one to the result).
- //
- // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
- // 1s) as the rounding bias, adds the rounding bias to the input, then
- // truncates the last 16 bits away.
- //
- // To understand how it works, we can analyze this algorithm case by case:
- //
- // 1. L = 0, R = 0:
- // Expect: round down, this is less than half value.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input may create any carry, depending on
- // whether there is any value set to 1 in T bits.
- // - R may be set to 1 if there is a carry.
- // - L remains 0.
- // - Note that this case also handles Inf and -Inf, where all fraction
- // bits, including L, R and Ts are all 0. The output remains Inf after
- // this algorithm.
- //
- // 2. L = 1, R = 0:
- // Expect: round down, this is less than half value.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 1 = 0x8000
- // - Adding rounding bias to input doesn't change sticky bits but
- // adds 1 to rounding bit.
- // - L remains 1.
- //
- // 3. L = 0, R = 1, all of T are 0:
- // Expect: round down, this is exactly at half, the result is already
- // even (L=0).
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input sets all sticky bits to 1, but
- // doesn't create a carry.
- // - R remains 1.
- // - L remains 0.
- //
- // 4. L = 1, R = 1:
- // Expect: round up, this is exactly at half, the result needs to be
- // round to the next even number.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 1 = 0x8000
- // - Adding rounding bias to input doesn't change sticky bits, but
- // creates a carry from rounding bit.
- // - The carry sets L to 0, creates another carry bit and propagate
- // forward to F bits.
- // - If all the F bits are 1, a carry then propagates to the exponent
- // bits, which then creates the minimum value with the next exponent
- // value. Note that we won't have the case where exponents are all 1,
- // since that's either a NaN (handled in the other if condition) or inf
- // (handled in case 1).
- //
- // 5. L = 0, R = 1, any of T is 1:
- // Expect: round up, this is greater than half.
- //
- // Algorithm:
- // - Rounding bias: 0x7fff + 0 = 0x7fff
- // - Adding rounding bias to input creates a carry from sticky bits,
- // sets rounding bit to 0, then create another carry.
- // - The second carry sets L to 1.
- //
- // Examples:
- //
- // Exact half value that is already even:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
- //
- // This falls into case 3. We truncate the rest of 16 bits and no
- // carry is created into F and L:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
- //
- // Exact half value, round to next even number:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
- //
- // This falls into case 4. We create a carry from R and T,
- // which then propagates into L and F:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
- //
- //
- // Max denormal value round to min normal value:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
- //
- // This falls into case 4. We create a carry from R and T,
- // propagate into L and F, which then propagates into exponent
- // bits:
- //
- // Output:
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
- //
- // Max normal value round to Inf:
- // Input:
- // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
- // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
- // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
- //
- // This falls into case 4. We create a carry from R and T,
- // propagate into L and F, which then propagates into exponent
- // bits:
- //
- // Sign | Exp (8 bit) | Frac (first 7 bit)
- // S E E E E E E E E F F F F F F L
- // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
- //
- //
- // Least significant bit of resulting bfloat.
- uint32_t lsb = (input >> 16) & 1;
- uint32_t rounding_bias = 0x7fff + lsb;
- input += rounding_bias;
- value = static_cast<uint16_t>(input >> 16);
- }
- }
-
- template <class T>
- explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
- : bfloat16(static_cast<float>(val)) {}
-
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
- float result;
-
- uint16_t* q = reinterpret_cast<uint16_t*>(&result);
-
+ EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- q[0] = value;
- q[1] = 0;
+ value = p[0];
#else
- q[0] = 0;
- q[1] = value;
+ value = p[1];
#endif
- return result;
- }
-
- EIGEN_DEVICE_FUNC explicit operator bool() const {
- return static_cast<bool>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
- return static_cast<Eigen::half>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator short() const {
- return static_cast<short>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator int() const {
- return static_cast<int>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator char() const {
- return static_cast<char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator signed char() const {
- return static_cast<signed char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
- return static_cast<unsigned char>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
- return static_cast<unsigned int>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
- return static_cast<unsigned long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
- return static_cast<unsigned long long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator long long() const {
- return static_cast<long long>(float(*this));
- }
-
- EIGEN_DEVICE_FUNC explicit operator double() const {
- return static_cast<double>(float(*this));
}
uint16_t value;
};
-inline bool operator==(const bfloat16 a, const bfloat16 b) {
- return a.value == b.value;
-}
-
-inline bool operator!=(const bfloat16 a, const bfloat16 b) {
- return a.value != b.value;
-}
-
} // end namespace tensorflow
namespace Eigen {
template <>
struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
-using ::tensorflow::operator==;
-using ::tensorflow::operator!=;
+EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
+ const tensorflow::bfloat16 b) {
+ return a.value == b.value;
+}
+
} // namespace Eigen
#ifdef COMPILER_MSVC