Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 121
1 file changed, 10 insertions(+), 111 deletions(-)
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 93d3cd425f..fda791401d 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,20 +33,6 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
-namespace {
-using tensorflow::int64;
-
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
-// Converts between little and big endian, assuming elements in the array are 16
-// bits long.
-void ConvertEndianShort(char* bytes, int64 size) {
-  CHECK_EQ(size / 2, 0);
-  for (int64 i = 0; i < size; i += 2) {
-    std::swap(bytes[i], bytes[i + 1]);
-  }
-}
-}  // namespace
 
 namespace xla {
 
@@ -183,8 +169,6 @@ Status Literal::Copy(const Literal& src_literal,
       return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
     case F16:
       return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
-    case BF16:
-      return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
     case F32:
       return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
     case F64:
@@ -216,8 +200,6 @@ Status Literal::Copy(const Literal& src_literal,
       return *Literal::CreateR0<int64>(0);
     case F16:
       return *Literal::CreateR0<half>(static_cast<half>(0.0f));
-    case BF16:
-      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
     case F32:
       return *Literal::CreateR0<float>(0);
     case F64:
@@ -303,9 +285,6 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(-std::numeric_limits<float>::infinity()));
-    case BF16:
-      return *Literal::CreateR0<bfloat16>(
-          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -342,9 +321,6 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(std::numeric_limits<float>::infinity()));
-    case BF16:
-      return *Literal::CreateR0<bfloat16>(
-          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -452,7 +428,6 @@ std::unique_ptr<Literal> Literal::Transpose(
   // The shape with affine layout resulting from that operation will be
   // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
   // most minor.
-  //
   // Essentially, given MinMaj(Di) the position of the Di dimension within the
   // minor to major vector, and given T(Di) the index that the original Di
   // dimension has within the transposed array, a layout is affine if
@@ -561,9 +536,6 @@ string Literal::GetAsString(
     }
     case F16:
       return tensorflow::strings::StrCat(Get<half>(multi_index));
-    case BF16:
-      return tensorflow::strings::StrCat(
-          static_cast<float>(Get<bfloat16>(multi_index)));
     default:
       return tensorflow::strings::StrCat(
           "[", PrimitiveType_Name(shape().element_type()), "]");
@@ -597,17 +569,9 @@ int64 Literal::LinearIndex(
   return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
 }
 
-string Literal::ToString(bool print_layout) const {
+string Literal::ToString() const {
   std::vector<string> pieces;
 
-  auto shape_to_string = [print_layout](const Shape& shape) {
-    if (print_layout) {
-      return ShapeUtil::HumanStringWithLayout(shape);
-    } else {
-      return ShapeUtil::HumanString(shape);
-    }
-  };
-
   auto element_to_string =
       [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
     PrimitiveType element_type = shape().element_type();
@@ -621,7 +585,7 @@ string Literal::ToString(bool print_layout) const {
 
   // TODO(b/32894291): refactor this code to reduce code duplication.
   if (ShapeUtil::IsTuple(shape())) {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" (\n");
     pieces.push_back(tensorflow::str_util::Join(
         tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -637,7 +601,7 @@
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 2) {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(" { ");
@@ -649,7 +613,7 @@
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 3) {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
      pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -664,7 +628,7 @@
     }
     pieces.push_back("\n}");
   } else if (ShapeUtil::Rank(shape()) == 4) {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -685,7 +649,7 @@
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 5) {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
@@ -712,7 +676,7 @@
     }
     pieces.push_back("}");
   } else {
-    pieces.push_back(shape_to_string(shape()));
+    pieces.push_back(ShapeUtil::HumanString(shape()));
     pieces.push_back(" {...}");
   }
 
@@ -771,8 +735,6 @@ void* Literal::MutableInternalData() {
      return reinterpret_cast<void*>(c64s_.data());
    case F16:
      return reinterpret_cast<void*>(f16s_.data());
-    case BF16:
-      return reinterpret_cast<void*>(bf16s_.data());
    default:
      LOG(FATAL) << "primitive type not supported in literals: "
                 << PrimitiveType_Name(shape().element_type());
@@ -815,9 +777,6 @@ void Literal::Reserve(int64 num_elements) {
    case F16:
      Resize<half>(num_elements, static_cast<half>(0.0f));
      break;
-    case BF16:
-      Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
-      break;
    default:
      LOG(FATAL) << "primitive type not supported in literals: "
                 << PrimitiveType_Name(shape().element_type());
@@ -857,9 +816,6 @@ tensorflow::Status Literal::ValidateLiteral() const {
    case F16:
      actual = f16s().size() / sizeof(half);
      break;
-    case BF16:
-      actual = bf16s().size();
-      break;
    default:
      return tensorflow::errors::Unimplemented(
          "unhandled element type for literal validation: " +
@@ -956,7 +912,6 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
    CONVERT_IF_TYPES_MATCH(F16)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
-    CONVERT_IF_TYPES_MATCH(BF16)
 #undef CONVERT_IF_TYPES_MATCH
    case C64:
      return ConvertToC64<primitive_src_type>(src_literal);
@@ -986,9 +941,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
    CONVERT_IF_DEST_TYPE_MATCHES(F16)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
-    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
 #undef CONVERT_IF_DEST_TYPE_MATCHES
-      // Other types are not yet supported.
+    // Other types are not yet supported.
    default:
      return InvalidArgument("Unimplemented: Convert from type %s to type %s",
                             PrimitiveType_Name(shape().element_type()).c_str(),
@@ -1057,8 +1011,6 @@ bool Literal::operator==(const Literal& other) const {
      return EqualElements<double>(*this, other, 0, &multi_index);
    case F16:
      return EqualElements<half>(*this, other, 0, &multi_index);
-    case BF16:
-      return EqualElements<bfloat16>(*this, other, 0, &multi_index);
    case C64:
      return EqualElements<complex64>(*this, other, 0, &multi_index);
    default:
@@ -1168,19 +1120,14 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
 template <>
 tensorflow::gtl::MutableArraySlice<half>
 Literal::GetMutableArraySlice<half>() {
+  // TODO - there is an endianess problem here. fix it, or wait for uint16
+  // support in protobuf
   auto values = mutable_f16s();
   return tensorflow::gtl::MutableArraySlice<half>(values->data(),
                                                   values->size());
 }
 
 template <>
-tensorflow::gtl::MutableArraySlice<bfloat16>
-Literal::GetMutableArraySlice<bfloat16>() {
-  auto values = mutable_bf16s();
-  return {values->data(), values->size()};
-}
-
-template <>
 tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
   CHECK_EQ(shape().element_type(), PRED);
   return tensorflow::gtl::ArraySlice<bool>(
@@ -1251,12 +1198,6 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
 }
 
 template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
-  CHECK_EQ(shape().element_type(), BF16);
-  return {bf16s().data(), bf16s().size()};
-}
-
-template <>
 tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
     const {
   CHECK_EQ(shape().element_type(), C64);
@@ -1304,9 +1245,6 @@ bool Literal::IsAll(int8 value) const {
      return AllElementsEqualValue<double>(*this, value);
    case F16:
      return AllElementsEqualValue<half>(*this, static_cast<half>(value));
-    case BF16:
-      return AllElementsEqualValue<bfloat16>(*this,
-                                             static_cast<bfloat16>(value));
    case PRED:
      if (value == 0) {
        return AllElementsEqualValue<bool>(*this, false);
@@ -1328,9 +1266,6 @@ bool Literal::IsAllFloat(float value) const {
      return AllElementsEqualValue<double>(*this, value);
    case F16:
      return AllElementsEqualValue<half>(*this, static_cast<half>(value));
-    case BF16:
-      return AllElementsEqualValue<bfloat16>(*this,
-                                             static_cast<bfloat16>(value));
    default:
      return false;
  }
@@ -1367,8 +1302,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
      return Get<complex64>(indices) == complex64(0.0f, 0.0f);
    case F16:
      return Get<half>(indices) == static_cast<half>(0.0f);
-    case BF16:
-      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
    case PRED:
      return Get<bool>(indices) == false;
    default:
@@ -1437,12 +1370,6 @@ void Literal::Resize<half>(int64 num_elements, half value) {
 }
 
 template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_bf16s()->resize(num_elements, value);
-}
-
-template <>
 void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
   mutable_c64s()->resize(num_elements, value);
@@ -1490,19 +1417,6 @@ LiteralProto Literal::ToProto() const {
      *proto.mutable_f16s() =
          string(reinterpret_cast<const char*>(f16s_.data()),
                 f16s_.size() * sizeof(half));
-      if (!kLittleEndian) {
-        ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
-                           proto.f16s().size());
-      }
-      break;
-    case BF16:
-      *proto.mutable_bf16s() =
-          string(reinterpret_cast<const char*>(bf16s_.data()),
-                 bf16s_.size() * sizeof(bfloat16));
-      if (!kLittleEndian) {
-        ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
-                           proto.bf16s().size());
-      }
      break;
    case F32:
      CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1571,21 +1485,6 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
      CHECK_EQ(0, s.size() % sizeof(half));
      f16s_ = std::vector<half>(s.size() / sizeof(half));
      memcpy(f16s_.data(), s.data(), s.size());
-
-      if (!kLittleEndian) {
-        ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
-      }
-      break;
-    }
-    case BF16: {
-      const string& s(literal_proto.bf16s());
-      CHECK_EQ(0, s.size() % sizeof(bfloat16));
-      bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
-      memcpy(bf16s_.data(), s.data(), s.size());
-
-      if (!kLittleEndian) {
-        ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
-      }
      break;
    }
    case F32:
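Editorial note, not part of the commit: the ConvertEndianShort helper deleted in the first hunk carries a broken precondition. CHECK_EQ(size / 2, 0) uses integer division, so it passes only when size is 0 or 1 and rejects every buffer the function was written to handle; given the comment ("elements in the array are 16 bits long"), the intended check was almost certainly that size is even. A minimal self-contained sketch of the corrected helper, using assert in place of the TensorFlow CHECK_EQ macro so it compiles outside the TF tree:

#include <cassert>
#include <cstdint>
#include <utility>

// Swaps the two bytes of every 16-bit element in `bytes`, converting the
// buffer between little-endian and big-endian representations in place.
// Corrected precondition: the original code checked size / 2 == 0.
void ConvertEndianShort(char* bytes, int64_t size) {
  assert(size % 2 == 0);  // byte count must be even for 16-bit elements
  for (int64_t i = 0; i < size; i += 2) {
    std::swap(bytes[i], bytes[i + 1]);
  }
}

The underlying concern survives the rollback: ToProto() still serializes the F16 buffer as raw bytes (f16s_.size() * sizeof(half)), so a big-endian reader of a proto written on a little-endian host would still need a byte swap of this kind, which is presumably why the commit adds the "TODO - there is an endianess problem here" comment to GetMutableArraySlice<half>().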