Diffstat (limited to 'tensorflow/compiler/xla/literal.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 1969
1 file changed, 1969 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc new file mode 100644 index 0000000000..0545deb096 --- /dev/null +++ b/tensorflow/compiler/xla/literal.cc @@ -0,0 +1,1969 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal.h" + +#include <algorithm> +#include <cstring> +#include <functional> +#include <limits> +#include <numeric> +#include <vector> + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + +namespace { + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian. + +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() % 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size % 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} + +} // namespace + +LiteralBase::~LiteralBase() {} + +std::ostream& operator<<(std::ostream& out, const Literal& literal) { + out << literal.ToString(); + return out; +} + +Literal::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice<int64> dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the largest minor dimension as the one upon + // which to run the tight stride loop.
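+ // Worked example (hypothetical shapes, for illustration only): copying + // dimensions = {3, 5} from a source with layout {1,0} into a dest with + // layout {0,1}. The source's most-minor dimension is 1 (size 5) and the + // dest's is 0 (size 3); since 5 >= 3, minor_dimension becomes 1, + // minor_loop_size becomes 5, and dest_stride becomes + // IndexUtil::GetDimensionStride(dest_shape, 1), i.e. 3, while the source + // is walked contiguously.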
+ if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + } else { + piece->set_buffer(new char[piece->size_bytes()]); + } + } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); + } +} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique<Shape>(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); + + SetPiece(*shape_, root_piece_, allocate_arrays); +} + +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; + } +} + +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); +} + +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + +Literal& Literal::operator=(Literal&& other) { + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + + return *this; +} + +std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) { + auto literal = MakeUnique<Literal>(shape); + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); + return literal; +} + +const SparseIndexArray* LiteralBase::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + +template <typename NativeT> +Status Literal::CopySliceFromInternal( + const 
LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice<int64> multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; + + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { + // If either of the two shapes is a scalar, we can just call StridedCopy() + // directly, and we know we will be copying only one value. + TF_RET_CHECK(copy_size.empty()); + StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0, + src_literal.data<NativeT>(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { + // Perform the copy only if neither src nor dest has a zero-element + // dimension; otherwise it's a no-op. + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumeration functor, do a strided copy advancing the source + // index by one (walking through the minor dimension), and the destination + // index by the proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + Literal::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); + + auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) { + // Map from multi-dimensional index to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus<int64>()); + // Map from multi-dimensional index to destination index.
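+ // (Illustration with hypothetical values: if dest_base = {0, 2} and the + // enumerated indexes = {1, 0}, the element-wise sum below yields + // dest_indexes = {1, 2}, mirroring the src_base addition above.)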
+ std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus<int64>()); + + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); + + // `this->` is needed to work around MSVC bug #16882 + StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride, + src_literal.data<NativeT>(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice<int64> src_index, + tensorflow::gtl::ArraySlice<int64> dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast<const char*>(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + +/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto( + const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique<Literal>(proto.shape()); + + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); + + return std::move(literal); +} + +std::vector<Literal> Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector<Literal> elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal.
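+ // (Usage sketch, hypothetical: for a tuple literal t of shape + // (f32[2], s32[]), std::vector<Literal> parts = t.DecomposeTuple(); hands + // the buffers over to parts[0] and parts[1] without copying, and t is + // left nil-shaped once the loop below finishes.)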
+ dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. +template <typename NativeT> +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest, + tensorflow::gtl::ArraySlice<NativeT> src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::IsZeroElementArray(dest_shape)) { + return; + } + std::vector<int64> index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Copying a Literal object with element type %s is not implemented.", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status Literal::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } + + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. 
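+ // (For example, hypothetically, with dest_shape_index = {1}, + // src_shape_index = {} and index = {1, 0}, the loop below appends + // index[1] = 0, so the piece at {1, 0} here is copied from the source + // piece at {0}.)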
+ ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); + + src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); + + return Status::OK(); +} + +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + tensorflow::gtl::ArraySlice<int64> dest_base, + tensorflow::gtl::ArraySlice<int64> copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); + + switch (shape().element_type()) { + case U8: + return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base, + copy_size); + case U16: + return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base, + copy_size); + case U32: + return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base, + copy_size); + case U64: + return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base, + copy_size); + case S8: + return CopySliceFromInternal<int8>(src_literal, src_base, dest_base, + copy_size); + case S16: + return CopySliceFromInternal<int16>(src_literal, src_base, dest_base, + copy_size); + case S32: + return CopySliceFromInternal<int32>(src_literal, src_base, dest_base, + copy_size); + case S64: + return CopySliceFromInternal<int64>(src_literal, src_base, dest_base, + copy_size); + case F16: + return CopySliceFromInternal<half>(src_literal, src_base, dest_base, + copy_size); + case BF16: + return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base, + copy_size); + case F32: + return CopySliceFromInternal<float>(src_literal, src_base, dest_base, + copy_size); + case F64: + return CopySliceFromInternal<double>(src_literal, src_base, dest_base, + copy_size); + case C64: + return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base, + copy_size); + case PRED: + return CopySliceFromInternal<bool>(src_literal, src_base, dest_base, + copy_size); + default: + break; + } + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); +} + +void Literal::PopulateR1(const 
tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + +std::unique_ptr<Literal> LiteralBase::Relayout( + const Layout& new_layout, const ShapeIndex& shape_index) const { + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique<Literal>(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +std::unique_ptr<Literal> LiteralBase::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; +} + +StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice<int64> dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. 
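+ // (Worked example, hypothetical: broadcasting f32[2,3] into f32[2,4,3] + // with dimensions = {0, 2}: for output_index = {1, 2, 0} the loop below + // computes scratch_source_index = {output_index[0], output_index[2]} = + // {1, 0}, and a single element is memcpy'd per output position.)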
+ std::vector<int64> scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast<char*>(result->untyped_data()); + const char* source_data = static_cast<const char*>(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape( + tensorflow::gtl::ArraySlice<int64> dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Reshape does not support tuples."); + } + std::unique_ptr<Literal> output; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + } else { + output = CloneToUnique(); + } + // Because the layout is monotonic, we can simply reuse the same sequence of + // values without changing their order. + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); + + int64 elements_before = ShapeUtil::ElementsIn(shape()); + int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + if (elements_before != elements_after) { + return InvalidArgument( + "Shapes before and after Literal::Reshape have different numbers " + "of elements: %s vs %s.", + ShapeUtil::HumanString(shape()).c_str(), + ShapeUtil::HumanString(output->shape()).c_str()); + } + return std::move(output); +} + +std::unique_ptr<Literal> LiteralBase::Transpose( + tensorflow::gtl::ArraySlice<int64> permutation) const { + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector<int64> inverse_permutation = InversePermutation(permutation); + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the one + // of size 8) as the most minor. + // + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout.
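+ // (Continuing the F32[11,8]{1,0} example with permutation = {1,0}: + // inverse_permutation is also {1,0}, so the loop below rewrites + // minor_to_major [1,0] as [inverse_permutation[1], inverse_permutation[0]] + // = [0,1], producing F32[8,11]{0,1} and keeping the flat memcpy valid.)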
+ CHECK(LayoutUtil::IsDenseArray(permuted_shape)); + Layout* layout = permuted_shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : LayoutUtil::MinorToMajor(shape())) { + layout->add_minor_to_major(inverse_permutation[index]); + } + auto new_literal = MakeUnique<Literal>(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + return new_literal; +} + +template <typename NativeT> +std::unique_ptr<Literal> LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice<int64> start_indices) const { + auto result_literal = MakeUnique<Literal>(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell<NativeT>( + [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get<NativeT>(new_indices); + result_literal->Set<NativeT>(indices, value); + }); + return result_literal; +} + +std::unique_ptr<Literal> LiteralBase::Slice( + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices) const { + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + + DimensionVector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + CHECK_GE(start_indices[dnum], 0); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; + int64 dimension = limit_indices[dnum] - start_indices[dnum]; + CHECK_GE(dimension, 0) << "dnum = " << dnum; + result_dimensions.push_back(dimension); + } + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); + switch (result_shape.element_type()) { + case F32: + return SliceInternal<float>(result_shape, start_indices); + case BF16: + return SliceInternal<bfloat16>(result_shape, start_indices); + case C64: + return SliceInternal<complex64>(result_shape, start_indices); + case S32: + return SliceInternal<int32>(result_shape, start_indices); + case U32: + return SliceInternal<uint32>(result_shape, start_indices); + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(result_shape.element_type()); + } +} + +Literal LiteralBase::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + +std::unique_ptr<Literal> LiteralBase::CloneToUnique() const { + auto result = MakeUnique<Literal>(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return Get<bool>(multi_index, shape_index) ? 
"true" : "false"; + case S8: + return StrCat(Get<int8>(multi_index, shape_index)); + case S16: + return StrCat(Get<int16>(multi_index, shape_index)); + case S32: + return StrCat(Get<int32>(multi_index, shape_index)); + case S64: + return StrCat(Get<int64>(multi_index, shape_index)); + case U8: + return StrCat(Get<uint8>(multi_index, shape_index)); + case U16: + return StrCat(Get<uint16>(multi_index, shape_index)); + case U32: + return StrCat(Get<uint32>(multi_index, shape_index)); + case U64: + return StrCat(Get<uint64>(multi_index, shape_index)); + case F16: + return StrCat(static_cast<float>(Get<half>(multi_index, shape_index))); + case F32: + return StrCat(Get<float>(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast<float>(Get<bfloat16>(multi_index, shape_index))); + case F64: + return StrCat(Get<double>(multi_index, shape_index)); + case C64: { + complex64 c = Get<complex64>(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement<bool>(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement<int16>(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement<int32>(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement<int64>(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement<uint8>(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement<uint16>(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement<uint32>(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement<uint64>(sparse_element_number, shape_index)); + case F16: + return StrCat(static_cast<float>( + GetSparseElement<half>(sparse_element_number, shape_index))); + case F32: + return StrCat( + GetSparseElement<float>(sparse_element_number, shape_index)); + case BF16: + return StrCat(static_cast<float>( + GetSparseElement<bfloat16>(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement<double>(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement<complex64>(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); + } +} + +StatusOr<int64> LiteralBase::GetIntegralAsS64( + tensorflow::gtl::ArraySlice<int64> multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + return Get<bool>(multi_index); + case U8: + return Get<uint8>(multi_index); + case S32: + return Get<int32>(multi_index); + case S64: + return Get<int64>(multi_index); + case U32: + return Get<uint32>(multi_index); + case U64: + return Get<uint64>(multi_index); + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } +} + +size_t LiteralBase::Hash() const { + using 
tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsArray(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast<const char*>(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + +Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index, + int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set<bool>(multi_index, value); + break; + case U8: + Set<uint8>(multi_index, value); + break; + case S32: + Set<int32>(multi_index, value); + break; + case S64: + Set<int64>(multi_index, value); + break; + case U32: + Set<uint32>(multi_index, value); + break; + case U64: + Set<uint64>(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + +tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); +} + +void Literal::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void LiteralBase::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal<bool>(); + break; + case S8: + SortSparseElementsInternal<int8>(); + break; + case U8: + SortSparseElementsInternal<uint8>(); + break; + case S16: + SortSparseElementsInternal<int16>(); + break; + case U16: + SortSparseElementsInternal<uint16>(); + break; + case S32: + SortSparseElementsInternal<int32>(); + break; + case U32: + SortSparseElementsInternal<uint32>(); + break; + case S64: + SortSparseElementsInternal<int64>(); + break; + case U64: + SortSparseElementsInternal<uint64>(); + break; + case F32: + SortSparseElementsInternal<float>(); + break; + case F64: + SortSparseElementsInternal<double>(); + break; + case C64: + SortSparseElementsInternal<complex64>(); + break; + case F16: + SortSparseElementsInternal<half>(); + break; + case BF16: + SortSparseElementsInternal<bfloat16>(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template <typename NativeT> +void LiteralBase::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data<NativeT>(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector<string>* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); 
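+ // (For a hypothetical f32[2,3] literal this lambda returns + // "f32[2,3]{1,0}" when print_layout is set and "f32[2,3]" otherwise.)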
+ } + }; + + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector<string> tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector<string> element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + + auto element_to_string = + [&](tensorflow::gtl::ArraySlice<int64> indices) -> string { + PrimitiveType element_type = subshape.element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return literal.Get<bool>(indices, shape_index) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + literal.GetAsString(indices, shape_index); + }; + + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); + } + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? 
",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); + } + pieces->push_back(" }"); + } + pieces->push_back(" }"); + } + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + } + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); + } + pieces->push_back("}"); + } else { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); + } +} + +} // namespace + +int64 LiteralBase::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + +string LiteralBase::ToString(bool print_layout) const { + std::vector<string> pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, print_layout, &pieces); + return tensorflow::str_util::Join(pieces, ""); +} + +void LiteralBase::EachCellAsString( + const std::function<void(tensorflow::gtl::ArraySlice<int64> indices, + const string& value)>& per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +namespace { +template <typename NativeSrcT, typename NativeDestT, typename ConverterType> +std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( + const LiteralBase& src_literal, const ConverterType& converter) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType<NativeDestT>())); + auto src_data = src_literal.data<NativeSrcT>(); + auto dest_data = result_literal->template data<NativeDestT>(); + int64 num_elements = src_literal.element_count(); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = converter(src_data[i]); + } + return result_literal; +} + +template <typename NativeSrcT, typename NativeDestT> +std::unique_ptr<Literal> ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); }; + return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>( + src_literal, converter); +} + +template <typename NativeSrcT, typename NativeDestT> +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr<Literal>>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast<NativeDestT>(src); + }; + return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. 
+template <typename NativeSrcT, typename NativeDestT> +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr<Literal>>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + +template <PrimitiveType primitive_src_type> +std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique<Literal>( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type; + tensorflow::gtl::ArraySlice<NativeSrcT> src_data = + src_literal.data<NativeSrcT>(); + tensorflow::gtl::MutableArraySlice<complex64> dest_data = + result_literal->data<complex64>(); + int64 num_elements = src_literal.element_count(); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast<float>(src_data[i]), 0); + } + return result_literal; +} + +template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type> +std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal, + bool bitcast) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } +} + +template <PrimitiveType primitive_src_type> +StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches( + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \ + bitcast); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F16) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) +#undef CONVERT_IF_TYPES_MATCH + case C64: + if (!bitcast) { + return ConvertToC64<primitive_src_type>(src_literal); + } + break; + // Other types are not yet supported. 
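+ // (Usage sketch, hypothetical: literal.Convert(F64) on an S32 literal + // reaches ConvertIfDestTypeMatches<S32> through the source-type switch in + // ConvertSwitch below, and then lands in the CONVERT_IF_TYPES_MATCH(F64) + // case above.)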
+ default: + break; + } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); +} + +StatusOr<std::unique_ptr<Literal>> ConvertSwitch( + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F16) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); + } +} + +} // namespace + +StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector<Literal> elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralSlice(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique<Literal>(); + *converted = Literal::MoveIntoTuple(&elements); + return std::move(converted); +} + +/* static */ Literal Literal::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice<Literal> elements) { + std::vector<Shape> element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); 
++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +template <typename NativeT> +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector<int64>* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal<NativeT>(other, multi_index)) { + return false; + } + multi_index->pop_back(); + } + return true; +} + +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector<int64> multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal<bool>(other, &multi_index); + case U8: + return EqualElementsInternal<uint8>(other, &multi_index); + case S32: + return EqualElementsInternal<int32>(other, &multi_index); + case S64: + return EqualElementsInternal<int64>(other, &multi_index); + case U32: + return EqualElementsInternal<uint32>(other, &multi_index); + case U64: + return EqualElementsInternal<uint64>(other, &multi_index); + case F32: + return EqualElementsInternal<float>(other, &multi_index); + case F64: + return EqualElementsInternal<double>(other, &multi_index); + case F16: + return EqualElementsInternal<half>(other, &multi_index); + case BF16: + return EqualElementsInternal<bfloat16>(other, &multi_index); + case C64: + return EqualElementsInternal<complex64>(other, &multi_index); + default: + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +bool LiteralBase::operator==(const LiteralBase& other) const { + if (!ShapeUtil::Compatible(shape(), other.shape())) { + return false; + } + + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); +} + +namespace { + +template <typename NativeT> +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { + return false; + } + } + return true; +} + +} // namespace + +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue<uint8>(piece.data<uint8>(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue<uint32>(piece.data<uint32>(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue<uint64>(piece.data<uint64>(), value); + } + return false; + case S8: + return AllElementsEqualValue<int8>(piece.data<int8>(), value); + case S32: + return AllElementsEqualValue<int32>(piece.data<int32>(), value); + case S64: + return AllElementsEqualValue<int64>(piece.data<int64>(), value); + case F32: + return AllElementsEqualValue<float>(piece.data<float>(), value); + case F64: + return 
AllElementsEqualValue<double>(piece.data<double>(), value); + case F16: + return AllElementsEqualValue<half>(piece.data<half>(), + static_cast<half>(value)); + case BF16: + return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(), + static_cast<bfloat16>(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue<bool>(piece.data<bool>(), false); + } + if (value == 1) { + return AllElementsEqualValue<bool>(piece.data<bool>(), true); + } + return false; + default: + return false; + } + return false; + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue<float>(piece.data<float>(), value); + case F64: + return AllElementsEqualValue<double>(piece.data<double>(), value); + case F16: + return AllElementsEqualValue<half>(piece.data<half>(), + static_cast<half>(value)); + case BF16: + return AllElementsEqualValue<bfloat16>( + piece.data<bfloat16>(), static_cast<bfloat16>(value)); + default: + return false; + } + }; + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue<complex64>(root_piece().data<complex64>(), + value); + default: + return false; + } +} + +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + // Empty shapes are not all the first element since there is no first + // element. 
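+ // (For example, hypothetically, an s32[3] literal {5, 5, 5} yields + // IsAllFirst() == true, {5, 6, 5} yields false, and a zero-element f32[0] + // yields false by the rule below.)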
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data<bool>(); + return AllElementsEqualValue<bool>(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data<int8>(); + return AllElementsEqualValue<int8>(data, data[0]); + } + case U8: { + auto data = piece.data<uint8>(); + return AllElementsEqualValue<uint8>(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data<bfloat16>(); + return AllElementsEqualValue<bfloat16>(data, data[0]); + } + case F16: { + auto data = piece.data<half>(); + return AllElementsEqualValue<half>(data, data[0]); + } + case S16: { + auto data = piece.data<int16>(); + return AllElementsEqualValue<int16>(data, data[0]); + } + case U16: { + auto data = piece.data<uint16>(); + return AllElementsEqualValue<uint16>(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data<float>(); + return AllElementsEqualValue<float>(data, data[0]); + } + case U32: { + auto data = piece.data<uint32>(); + return AllElementsEqualValue<uint32>(data, data[0]); + } + case S32: { + auto data = piece.data<int32>(); + return AllElementsEqualValue<int32>(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data<complex64>(); + return AllElementsEqualValue<complex64>(data, data[0]); + } + case F64: { + auto data = piece.data<double>(); + return AllElementsEqualValue<double>(data, data[0]); + } + case S64: { + auto data = piece.data<int64>(); + return AllElementsEqualValue<int64>(data, data[0]); + } + case U64: { + auto data = piece.data<uint64>(); + return AllElementsEqualValue<uint64>(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const { + CHECK(ShapeUtil::IsArray(shape())); + switch (shape().element_type()) { + case U8: + return Get<uint8>(indices) == 0; + case U32: + return Get<uint32>(indices) == 0; + case U64: + return Get<uint64>(indices) == 0; + case S8: + return Get<int8>(indices) == 0; + case S32: + return Get<int32>(indices) == 0; + case S64: + return Get<int64>(indices) == 0; + case F32: + return Get<float>(indices) == 0.0f; + case F64: + return Get<double>(indices) == 0.0; + case C64: + return Get<complex64>(indices) == complex64(0.0f, 0.0f); + case F16: + return Get<half>(indices) == static_cast<half>(0.0f); + case BF16: + return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f); + case PRED: + return Get<bool>(indices) == false; + default: + LOG(FATAL) << "Input literal must be an array."; + } +} + +namespace { + +template <typename RepeatedFieldT, typename NativeT> +void CopyToRepeatedField(RepeatedFieldT* dest, + const tensorflow::gtl::ArraySlice<NativeT> src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +} // namespace + +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { + case PRED: + CopyToRepeatedField(proto->mutable_preds(), data<bool>()); + break; + case U8: + proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()), + element_count()); + break; + case U32: + CopyToRepeatedField(proto->mutable_u32s(), data<uint32>()); + break; + case U64: + CopyToRepeatedField(proto->mutable_u64s(), data<uint64>()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data<int32>()); + break; + 
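+ // (Serialization sketch: a hypothetical s32[3] literal {1, 2, 3} is + // stored as three repeated values in proto.s32s(); CopyFromProto further + // below restores them with CopyFromRepeatedField.)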
+ case S64:
+ CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
+ break;
+ case F16:
+ *proto->mutable_f16s() = string(
+ reinterpret_cast<const char*>(data<half>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_f16s());
+ }
+ break;
+ case BF16:
+ *proto->mutable_bf16s() = string(
+ reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_bf16s());
+ }
+ break;
+ case F32:
+ CopyToRepeatedField(proto->mutable_f32s(), data<float>());
+ break;
+ case F64:
+ CopyToRepeatedField(proto->mutable_f64s(), data<double>());
+ break;
+ case C64:
+ for (complex64 value : data<complex64>()) {
+ proto->add_c64s(value.real());
+ proto->add_c64s(value.imag());
+ }
+ break;
+ case TUPLE:
+ case TOKEN:
+ // Nothing to do beyond assigning the shape, which is done above.
+ return;
+ default:
+ // TODO(b/111551621): Support serializing more PrimitiveTypes.
+ LOG(FATAL) << "Unhandled primitive type "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+const void* LiteralBase::Piece::untyped_data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+void* LiteralBase::Piece::untyped_data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ const RepeatedFieldT& src) {
+ if (dest.size() != src.size()) {
+ return InvalidArgument(
+ "Expected %lu elements in LiteralProto repeated field, has %d",
+ dest.size(), src.size());
+ }
+ std::copy(src.begin(), src.end(), dest.begin());
+ return Status::OK();
+}
+
+} // namespace
+
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
+ // These conditions should have been checked in Literal::CreateFromProto.
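+ // (Callers are expected to go through Literal::CreateFromProto, which
+ // performs the user-facing validation; the TF_RET_CHECKs below are
+ // internal sanity checks only.)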
+ TF_RET_CHECK(proto.has_shape());
+ TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+
+ switch (subshape().element_type()) {
+ case PRED:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
+ break;
+ case U8: {
+ auto u8_data = data<uint8>();
+ TF_RET_CHECK(proto.u8s().size() == u8_data.size());
+ std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
+ } break;
+ case S32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
+ break;
+ case S64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
+ break;
+ case U32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
+ break;
+ case U64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
+ break;
+ case F16: {
+ const string& s(proto.f16s());
+ TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+
+ case BF16: {
+ const string& s(proto.bf16s());
+ TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+ case F32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
+ break;
+ case F64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
+ break;
+ case C64: {
+ auto complex_data = data<complex64>();
+ TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
+ for (int64 i = 0; i < complex_data.size(); ++i) {
+ complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
+ }
+ } break;
+ case TUPLE:
+ LOG(FATAL) << "Should not be called on tuple shapes: "
+ << ShapeUtil::HumanString(subshape());
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+ return Status::OK();
+}
+
+LiteralProto LiteralBase::ToProto() const {
+ LiteralProto proto;
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
+
+ if (LayoutUtil::IsSparseArray(shape())) {
+ CopyToRepeatedField(proto.mutable_sparse_indices(),
+ sparse_indices()->data());
+ }
+
+ return proto;
+}
+
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
+ return piece(shape_index).untyped_data();
+}
+
+void* Literal::untyped_data(const ShapeIndex& shape_index) {
+ return piece(shape_index).untyped_data();
+}
+
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
+ return piece(shape_index).size_bytes();
+}
+
+string LiteralBase::GetR1U8AsString() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(shape().element_type(), U8);
+ return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+ ShapeUtil::ElementsIn(shape()));
+}
+
+void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
+ CHECK(ShapeUtil::IsTuple(shape));
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
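+ // (Only the piece tree is built here; the borrowed data buffers are
+ // attached afterwards by the tuple BorrowingLiteral constructor below.)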
+ child_piece.set_subshape(&subshape);
+
+ if (ShapeUtil::IsTuple(subshape)) {
+ BuildPieceSubtree(subshape, &child_piece);
+ }
+
+ piece->emplace_back(std::move(child_piece));
+ }
+}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
+BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsArray(*shape_));
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = Piece();
+ root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
+ root_piece_.set_subshape(shape_.get());
+}
+
+BorrowingLiteral::BorrowingLiteral(
+ tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsTuple(*shape_));
+ CHECK(!ShapeUtil::IsNestedTuple(*shape_));
+ CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
+ root_piece_ = Piece();
+ root_piece_.set_subshape(shape_.get());
+ BuildPieceSubtree(*shape_, &root_piece_);
+
+ for (int i = 0; i < src_buf_ptrs.size(); ++i) {
+ const auto& src_shape = shape_->tuple_shapes(i);
+ CHECK(ShapeUtil::IsArray(src_shape));
+ root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
+ }
+}
+
+} // namespace xla