diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 935 |
1 files changed, 510 insertions, 425 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index b3b5e34ba2..82a2bcad76 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template <typename FromNativeT, typename ToNativeT> +std::unique_ptr<Literal> ConvertType(LiteralSlice literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType<FromNativeT>()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType<ToNativeT>()); + } + }); + auto result = MakeUnique<Literal>(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType<FromNativeT>()) { + auto src = literal.data<FromNativeT>(shape_index); + auto dest = result->data<ToNativeT>(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast<ToNativeT>(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -95,99 +136,90 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique<Shape>(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree<Piece>(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree<Piece>(other.shape_); - other.piece({}).set_subshape(&other.shape_); + CHECK(&other.root_piece_->subshape() == other.shape_.get()); + + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + CHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique<Literal>(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -202,9 +234,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType<bfloat16, float>(bf16_literal); +} + +/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType<float, bfloat16>(f32_literal); +} + template <typename NativeT> Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, tensorflow::gtl::ArraySlice<int64> dest_base, tensorflow::gtl::ArraySlice<int64> copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -217,8 +259,8 @@ Status Literal::CopySliceFromInternal( if (ShapeUtil::Rank(src_literal.shape()) == 0 || ShapeUtil::Rank(shape()) == 0) { - // If any of the two shapes are scalars, we can just call the StridedCopy() - // directly, and we know we will be copying only one value. + // If any of the two shapes are scalars, we can just call the + // StridedCopy() directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0, src_literal.data<NativeT>(), @@ -264,7 +306,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice<int64> src_index, tensorflow::gtl::ArraySlice<int64> dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -293,22 +335,21 @@ std::vector<Literal> Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -331,9 +372,9 @@ std::vector<Literal> Literal::DecomposeTuple() { } namespace { - -// Copies the elements in 'src' to 'dest'. The shape and layout of the data in -// the array slices are indicated by dest_shape and src_shape respectively. +// Copies the elements in 'src' to 'dest'. The shape and layout of the data +// in the array slices are indicated by dest_shape and src_shape +// respectively. template <typename NativeT> void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest, tensorflow::gtl::ArraySlice<NativeT> src, @@ -351,7 +392,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -381,14 +422,15 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { #undef COPY_ELEMENTS default: return Unimplemented( - "Copying a Literal object with element type %s is not implemented.", + "Copying a Literal object with element type %s is not " + "implemented.", PrimitiveType_Name(subshape().element_type()).c_str()); } } return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -402,36 +444,33 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(src_subshape).c_str()); } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); -} + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} // namespace xla Status Literal::MoveFrom(Literal&& src_literal, const ShapeIndex& dest_shape_index) { @@ -444,37 +483,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, tensorflow::gtl::ArraySlice<int64> dest_base, tensorflow::gtl::ArraySlice<int64> copy_size) { @@ -743,7 +777,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr<Literal> Literal::Relayout( +std::unique_ptr<Literal> LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); @@ -755,7 +789,7 @@ std::unique_ptr<Literal> Literal::Relayout( return result; } -std::unique_ptr<Literal> Literal::Relayout( +std::unique_ptr<Literal> LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -774,7 +808,7 @@ std::unique_ptr<Literal> Literal::Relayout( return result; } -StatusOr<std::unique_ptr<Literal>> Literal::Reshape( +StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape( tensorflow::gtl::ArraySlice<int64> dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -788,7 +822,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -802,7 +837,79 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape( return std::move(output); } -std::unique_ptr<Literal> Literal::Transpose( +/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice<int64> new_dimensions, + tensorflow::gtl::ArraySlice<int64> minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique<Literal>( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector<int64> from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector<int64> to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set<bool>(to_multi_index, + literal.Get<bool>(from_multi_index)); + break; + case U8: + new_literal->Set<uint8>(to_multi_index, + literal.Get<uint8>(from_multi_index)); + break; + case U32: + new_literal->Set<uint32>(to_multi_index, + literal.Get<uint32>(from_multi_index)); + break; + case S32: + new_literal->Set<int32>(to_multi_index, + literal.Get<int32>(from_multi_index)); + break; + case U64: + new_literal->Set<uint64>(to_multi_index, + literal.Get<uint64>(from_multi_index)); + break; + case S64: + new_literal->Set<int64>(to_multi_index, + literal.Get<int64>(from_multi_index)); + break; + case F32: + new_literal->Set<float>(to_multi_index, + literal.Get<float>(from_multi_index)); + break; + case F64: + new_literal->Set<double>(to_multi_index, + literal.Get<double>(from_multi_index)); + break; + case C64: + new_literal->Set<complex64>(to_multi_index, + literal.Get<complex64>(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr<Literal> LiteralBase::Transpose( tensorflow::gtl::ArraySlice<int64> permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -819,8 +926,8 @@ std::unique_ptr<Literal> Literal::Transpose( // representation intact. // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the - // most minor. + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), + // the most minor. // // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di @@ -836,12 +943,11 @@ std::unique_ptr<Literal> Literal::Transpose( std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr<Literal> Literal::Slice( +std::unique_ptr<Literal> LiteralBase::Slice( tensorflow::gtl::ArraySlice<int64> start_indices, tensorflow::gtl::ArraySlice<int64> limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -909,20 +1015,20 @@ std::unique_ptr<Literal> Literal::Slice( } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr<Literal> Literal::CloneToUnique() const { +std::unique_ptr<Literal> LiteralBase::CloneToUnique() const { auto result = MakeUnique<Literal>(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -962,8 +1068,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1017,7 +1123,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr<int64> Literal::GetIntegralAsS64( +StatusOr<int64> LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice<int64> multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1070,7 +1176,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1082,10 +1188,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { case PRED: return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>())); // 8 bit types. @@ -1121,11 +1227,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal<bool>(); @@ -1176,7 +1282,7 @@ void Literal::Piece::SortSparseElements() { } template <typename NativeT> -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data<NativeT>(); @@ -1186,10 +1292,11 @@ void Literal::Piece::SortSparseElementsInternal() { } namespace { - -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector<string>* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1348,13 +1455,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector<string> pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1362,7 +1470,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr<Literal> Literal::MakeTuple( tensorflow::gtl::ArraySlice<const Literal*> elements) { std::vector<Shape> element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1372,6 +1480,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice<LiteralSlice> elements) { + std::vector<Shape> element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned( std::vector<std::unique_ptr<Literal>> elements) { std::vector<Shape> element_shapes; @@ -1387,7 +1508,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function<void(tensorflow::gtl::ArraySlice<int64> indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1403,7 +1524,7 @@ void Literal::EachCellAsString( namespace { template <typename NativeSrcT, typename NativeDestT, typename ConverterType> std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1419,7 +1540,8 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( } template <typename NativeSrcT, typename NativeDestT> -std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr<Literal> ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); }; return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>( src_literal, converter); @@ -1428,7 +1550,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) { template <typename NativeSrcT, typename NativeDestT> typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr<Literal>>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast<NativeDestT>(src); }; @@ -1436,19 +1558,19 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { src_literal, converter); } -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. +// This template specialization is here to make the compiler happy. bit_cast +// has a static check that the types are the same size. This specialization +// should never be used because the source and destination types are checked +// for identical sizes higher up. template <typename NativeSrcT, typename NativeDestT> typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr<Literal>>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template <PrimitiveType primitive_src_type> -std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) { +std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique<Literal>( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1466,7 +1588,7 @@ std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) { } template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type> -std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1486,7 +1608,7 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal, template <PrimitiveType primitive_src_type> StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1521,7 +1643,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches( } StatusOr<std::unique_ptr<Literal>> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1555,17 +1678,18 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch( } // namespace -StatusOr<std::unique_ptr<Literal>> Literal::Convert( +StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert( +StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { return InvalidArgument( - "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "Cannot bitcast convert from %s to %s, bit widths are different: %d " + "!= " "%d", PrimitiveType_Name(shape().element_type()).c_str(), PrimitiveType_Name(primitive_dest_type).c_str(), @@ -1575,7 +1699,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape( +StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1590,7 +1714,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape( } std::vector<Literal> elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1602,8 +1726,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape( } template <typename NativeT> -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector<int64>* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector<int64>* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index)); } @@ -1617,7 +1741,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector<int64> multi_index; @@ -1645,32 +1769,31 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal<complex64>(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { - template <typename NativeT> static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data, NativeT value) { @@ -1684,11 +1807,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1741,41 +1864,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; -} + return true; + }); +} // namespace xla -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue<float>(piece.data<float>(), value); - case F64: - return AllElementsEqualValue<double>(piece.data<double>(), value); - case F16: - return AllElementsEqualValue<half>(piece.data<half>(), - static_cast<half>(value)); - case BF16: - return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(), - static_cast<bfloat16>(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue<float>(piece.data<float>(), value); + case F64: + return AllElementsEqualValue<double>(piece.data<double>(), value); + case F16: + return AllElementsEqualValue<half>(piece.data<half>(), + static_cast<half>(value)); + case BF16: + return AllElementsEqualValue<bfloat16>( + piece.data<bfloat16>(), static_cast<bfloat16>(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue<complex64>(root_piece().data<complex64>(), @@ -1785,93 +1908,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data<bool>(); - return AllElementsEqualValue<bool>(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data<int8>(); - return AllElementsEqualValue<int8>(data, data[0]); - } - case U8: { - auto data = piece.data<uint8>(); - return AllElementsEqualValue<uint8>(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data<bfloat16>(); - return AllElementsEqualValue<bfloat16>(data, data[0]); - } - case F16: { - auto data = piece.data<half>(); - return AllElementsEqualValue<half>(data, data[0]); +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; } - case S16: { - auto data = piece.data<int16>(); - return AllElementsEqualValue<int16>(data, data[0]); - } - case U16: { - auto data = piece.data<uint16>(); - return AllElementsEqualValue<uint16>(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data<float>(); - return AllElementsEqualValue<float>(data, data[0]); - } - case U32: { - auto data = piece.data<uint32>(); - return AllElementsEqualValue<uint32>(data, data[0]); - } - case S32: { - auto data = piece.data<int32>(); - return AllElementsEqualValue<int32>(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data<complex64>(); - return AllElementsEqualValue<complex64>(data, data[0]); - } - case F64: { - auto data = piece.data<double>(); - return AllElementsEqualValue<double>(data, data[0]); - } - case S64: { - auto data = piece.data<int64>(); - return AllElementsEqualValue<int64>(data, data[0]); - } - case U64: { - auto data = piece.data<uint64>(); - return AllElementsEqualValue<uint64>(data, data[0]); - } - default: + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data<bool>(); + return AllElementsEqualValue<bool>(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data<int8>(); + return AllElementsEqualValue<int8>(data, data[0]); + } + case U8: { + auto data = piece.data<uint8>(); + return AllElementsEqualValue<uint8>(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data<bfloat16>(); + return AllElementsEqualValue<bfloat16>(data, data[0]); + } + case F16: { + auto data = piece.data<half>(); + return AllElementsEqualValue<half>(data, data[0]); + } + case S16: { + auto data = piece.data<int16>(); + return AllElementsEqualValue<int16>(data, data[0]); + } + case U16: { + auto data = piece.data<uint16>(); + return AllElementsEqualValue<uint16>(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data<float>(); + return AllElementsEqualValue<float>(data, data[0]); + } + case U32: { + auto data = piece.data<uint32>(); + return AllElementsEqualValue<uint32>(data, data[0]); + } + case S32: { + auto data = piece.data<int32>(); + return AllElementsEqualValue<int32>(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data<complex64>(); + return AllElementsEqualValue<complex64>(data, data[0]); + } + case F64: { + auto data = piece.data<double>(); + return AllElementsEqualValue<double>(data, data[0]); + } + case S64: { + auto data = piece.data<int64>(); + return AllElementsEqualValue<int64>(data, data[0]); + } + case U64: { + auto data = piece.data<uint64>(); + return AllElementsEqualValue<uint64>(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1904,7 +2027,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const { } namespace { - template <typename RepeatedFieldT, typename NativeT> void CopyToRepeatedField(RepeatedFieldT* dest, const tensorflow::gtl::ArraySlice<NativeT> src) { @@ -1913,7 +2035,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1969,18 +2091,17 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } namespace { - template <typename RepeatedFieldT, typename NativeT> Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest, const RepeatedFieldT& src) { @@ -1995,7 +2116,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2062,21 +2183,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2098,33 +2217,39 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto( auto literal = MakeUnique<Literal>(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice<int64> multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2132,11 +2257,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2144,12 +2269,14 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -size_t Literal::Hash() const { +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + +size_t LiteralBase::Hash() const { using tensorflow::Hash64; using tensorflow::Hash64Combine; @@ -2170,46 +2297,4 @@ size_t Literal::Hash() const { return hash_value; } -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree<Piece>(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); - } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } -} - -LiteralView::~LiteralView() {} - -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } - -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; -} - -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. - shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = false; -} - } // namespace xla |