Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc  121
1 file changed, 10 insertions(+), 111 deletions(-)
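
For reference, here is a minimal standalone sketch of the 16-bit byte-swap helper that the first hunk of this diff deletes. It is not the TensorFlow implementation: it replaces the CHECK macro with a plain assert, and it writes the evenness check as size % 2 == 0, which appears to be the intent behind the original CHECK_EQ(size / 2, 0).

#include <cassert>
#include <cstdint>
#include <utility>

// Detect host byte order with the same compiler-provided macros the deleted
// code relied on (available on GCC and Clang).
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;

// Converts between little and big endian in place, assuming the buffer holds
// 16-bit elements, i.e. `size` bytes with `size` even.
void ConvertEndianShort(char* bytes, int64_t size) {
  assert(size % 2 == 0);  // an odd byte count cannot hold 16-bit elements
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}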
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 93d3cd425f..fda791401d 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,20 +33,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-namespace {
-using tensorflow::int64;
-
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
-// Converts between little and big endian, assuming elements in the array are 16
-// bits long.
-void ConvertEndianShort(char* bytes, int64 size) {
- CHECK_EQ(size / 2, 0);
- for (int64 i = 0; i < size; i += 2) {
- std::swap(bytes[i], bytes[i + 1]);
- }
-}
-} // namespace
namespace xla {
@@ -183,8 +169,6 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
case F16:
return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
- case BF16:
- return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
case F32:
return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64:
@@ -216,8 +200,6 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int64>(0);
case F16:
return *Literal::CreateR0<half>(static_cast<half>(0.0f));
- case BF16:
- return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
return *Literal::CreateR0<float>(0);
case F64:
@@ -303,9 +285,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity()));
- case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -342,9 +321,6 @@ Status Literal::Copy(const Literal& src_literal,
case F16:
return *Literal::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity()));
- case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -452,7 +428,6 @@ std::unique_ptr<Literal> Literal::Transpose(
// The shape with affine layout resulting from that operation will be
// F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
// most minor.
- //
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
// dimension has within the transposed array, a layout is affine if
@@ -561,9 +536,6 @@ string Literal::GetAsString(
}
case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index));
- case BF16:
- return tensorflow::strings::StrCat(
- static_cast<float>(Get<bfloat16>(multi_index)));
default:
return tensorflow::strings::StrCat(
"[", PrimitiveType_Name(shape().element_type()), "]");
@@ -597,17 +569,9 @@ int64 Literal::LinearIndex(
return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
}
-string Literal::ToString(bool print_layout) const {
+string Literal::ToString() const {
std::vector<string> pieces;
- auto shape_to_string = [print_layout](const Shape& shape) {
- if (print_layout) {
- return ShapeUtil::HumanStringWithLayout(shape);
- } else {
- return ShapeUtil::HumanString(shape);
- }
- };
-
auto element_to_string =
[this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
PrimitiveType element_type = shape().element_type();
@@ -621,7 +585,7 @@ string Literal::ToString(bool print_layout) const {
// TODO(b/32894291): refactor this code to reduce code duplication.
if (ShapeUtil::IsTuple(shape())) {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" (\n");
pieces.push_back(tensorflow::str_util::Join(
tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -637,7 +601,7 @@ string Literal::ToString(bool print_layout) const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 2) {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(" { ");
@@ -649,7 +613,7 @@ string Literal::ToString(bool print_layout) const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 3) {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -664,7 +628,7 @@ string Literal::ToString(bool print_layout) const {
}
pieces.push_back("\n}");
} else if (ShapeUtil::Rank(shape()) == 4) {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -685,7 +649,7 @@ string Literal::ToString(bool print_layout) const {
}
pieces.push_back("}");
} else if (ShapeUtil::Rank(shape()) == 5) {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {\n");
for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -712,7 +676,7 @@ string Literal::ToString(bool print_layout) const {
}
pieces.push_back("}");
} else {
- pieces.push_back(shape_to_string(shape()));
+ pieces.push_back(ShapeUtil::HumanString(shape()));
pieces.push_back(" {...}");
}
@@ -771,8 +735,6 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(c64s_.data());
case F16:
return reinterpret_cast<void*>(f16s_.data());
- case BF16:
- return reinterpret_cast<void*>(bf16s_.data());
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -815,9 +777,6 @@ void Literal::Reserve(int64 num_elements) {
case F16:
Resize<half>(num_elements, static_cast<half>(0.0f));
break;
- case BF16:
- Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
- break;
default:
LOG(FATAL) << "primitive type not supported in literals: "
<< PrimitiveType_Name(shape().element_type());
@@ -857,9 +816,6 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F16:
actual = f16s().size() / sizeof(half);
break;
- case BF16:
- actual = bf16s().size();
- break;
default:
return tensorflow::errors::Unimplemented(
"unhandled element type for literal validation: " +
@@ -956,7 +912,6 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F16)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
- CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
return ConvertToC64<primitive_src_type>(src_literal);
@@ -986,9 +941,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
CONVERT_IF_DEST_TYPE_MATCHES(F16)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
- CONVERT_IF_DEST_TYPE_MATCHES(BF16)
#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
+ // Other types are not yet supported.
default:
return InvalidArgument("Unimplemented: Convert from type %s to type %s",
PrimitiveType_Name(shape().element_type()).c_str(),
@@ -1057,8 +1011,6 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index);
case F16:
return EqualElements<half>(*this, other, 0, &multi_index);
- case BF16:
- return EqualElements<bfloat16>(*this, other, 0, &multi_index);
case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index);
default:
@@ -1168,19 +1120,14 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
+ // TODO - there is an endianess problem here. fix it, or wait for uint16
+ // support in protobuf
auto values = mutable_f16s();
return tensorflow::gtl::MutableArraySlice<half>(values->data(),
values->size());
}
template <>
-tensorflow::gtl::MutableArraySlice<bfloat16>
-Literal::GetMutableArraySlice<bfloat16>() {
- auto values = mutable_bf16s();
- return {values->data(), values->size()};
-}
-
-template <>
tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
CHECK_EQ(shape().element_type(), PRED);
return tensorflow::gtl::ArraySlice<bool>(
@@ -1251,12 +1198,6 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
}
template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
- CHECK_EQ(shape().element_type(), BF16);
- return {bf16s().data(), bf16s().size()};
-}
-
-template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const {
CHECK_EQ(shape().element_type(), C64);
@@ -1304,9 +1245,6 @@ bool Literal::IsAll(int8 value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
case PRED:
if (value == 0) {
return AllElementsEqualValue<bool>(*this, false);
@@ -1328,9 +1266,6 @@ bool Literal::IsAllFloat(float value) const {
return AllElementsEqualValue<double>(*this, value);
case F16:
return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
default:
return false;
}
@@ -1367,8 +1302,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
- case BF16:
- return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
case PRED:
return Get<bool>(indices) == false;
default:
@@ -1437,12 +1370,6 @@ void Literal::Resize<half>(int64 num_elements, half value) {
}
template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_bf16s()->resize(num_elements, value);
-}
-
-template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_c64s()->resize(num_elements, value);
@@ -1490,19 +1417,6 @@ LiteralProto Literal::ToProto() const {
*proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()),
f16s_.size() * sizeof(half));
- if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
- proto.f16s().size());
- }
- break;
- case BF16:
- *proto.mutable_bf16s() =
- string(reinterpret_cast<const char*>(bf16s_.data()),
- bf16s_.size() * sizeof(bfloat16));
- if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
- proto.bf16s().size());
- }
break;
case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1571,21 +1485,6 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half));
memcpy(f16s_.data(), s.data(), s.size());
-
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
- }
- break;
- }
- case BF16: {
- const string& s(literal_proto.bf16s());
- CHECK_EQ(0, s.size() % sizeof(bfloat16));
- bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
- memcpy(bf16s_.data(), s.data(), s.size());
-
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
- }
break;
}
case F32: