about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/literal_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.cc')
-rw-r--r-- tensorflow/compiler/xla/literal_util.cc | 935
1 files changed, 510 insertions, 425 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index b3b5e34ba2..82a2bcad76 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -62,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) {
}
}
+// Return a literal with all arrays of type FromNativeT converted to type
+// ToNativeT in the given literal.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+ // First construct shape of the result.
+ Shape result_shape(literal.shape());
+ ShapeUtil::ForEachMutableSubshape(
+ &result_shape, [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ subshape->set_element_type(
+ primitive_util::NativeToPrimitiveType<ToNativeT>());
+ }
+ });
+ auto result = MakeUnique<Literal>(result_shape);
+
+ // Then copy over the data from 'literal' converting FromNativeT values to
+ // ToNativeT values as necessary.
+ ShapeUtil::ForEachSubshape(
+ literal.shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ if (subshape.element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ auto src = literal.data<FromNativeT>(shape_index);
+ auto dest = result->data<ToNativeT>(shape_index);
+ for (int64 i = 0; i < src.size(); ++i) {
+ dest[i] = static_cast<ToNativeT>(src[i]);
+ }
+ } else {
+ TF_CHECK_OK(result->CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
+ }
+ }
+ });
+ return result;
+}
+
} // namespace
+LiteralBase::~LiteralBase() {}
+
std::ostream& operator<<(std::ostream& out, const Literal& literal) {
out << literal.ToString();
return out;
@@ -95,99 +136,90 @@ Literal::StrideConfig::StrideConfig(
Literal::Literal(const Shape& shape)
: Literal(shape, /*allocate_arrays=*/true) {}
-Literal::Literal(const Shape& shape, bool allocate_arrays)
- : shape_(shape), pieces_(shape), owns_buffers_(true) {
- CHECK(LayoutUtil::HasLayout(shape));
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- const Shape& subshape = piece.subshape();
- if (ShapeUtil::IsArray(subshape)) {
- if (allocate_arrays) {
- if (LayoutUtil::IsSparseArray(subshape)) {
- // For sparse arrays, the buffer must be of the size of the maximum
- // number of sparse elements possible.
- const int64 max_sparse_elements =
- LayoutUtil::MaxSparseElements(subshape.layout());
- piece.set_buffer(
- new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
- subshape.element_type())]);
- piece.set_sparse_indices(new SparseIndexArray(
- max_sparse_elements, ShapeUtil::Rank(subshape)));
- } else {
- piece.set_buffer(new char[piece.size_bytes()]);
- }
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ SetPiece(subshape, &child_piece, allocate_arrays);
+
+ piece->emplace_back(std::move(child_piece));
+ }
+ } else {
+ CHECK(ShapeUtil::IsArray(shape));
+ if (allocate_arrays) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(shape.layout());
+ piece->set_buffer(
+ new char[max_sparse_elements *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+ piece->set_sparse_indices(
+ new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
} else {
- piece.set_buffer(nullptr);
+ piece->set_buffer(new char[piece->size_bytes()]);
}
}
}
}
-Literal::~Literal() { DeallocateBuffers(); }
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+ CHECK(&root_piece_->subshape() == shape_.get());
-void Literal::DeallocateBuffers() {
- if (owns_buffers_) {
- for (auto& pair : pieces_) {
- Piece& piece = pair.second;
- if (piece.buffer() != nullptr) {
- delete[] piece.buffer();
- delete piece.sparse_indices();
- }
- }
- }
+ SetPiece(*shape_, root_piece_, allocate_arrays);
}
-Literal::Literal(Literal&& other) {
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+Literal::~Literal() {
+ if (root_piece_ != nullptr) {
+ DeallocateBuffers();
+ delete root_piece_;
}
- owns_buffers_ = other.owns_buffers_;
+}
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+void Literal::DeallocateBuffers() {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete[] piece->buffer();
+ delete piece->sparse_indices();
+ }
+ });
}
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
Literal& Literal::operator=(Literal&& other) {
- DeallocateBuffers();
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = other.owns_buffers_;
-
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+ CHECK(&other.root_piece_->subshape() == other.shape_.get());
+
+ using std::swap;
+ swap(shape_, other.shape_);
+ swap(root_piece_, other.root_piece_);
+ CHECK(&root_piece_->subshape() == shape_.get());
+
return *this;
}
-std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
auto literal = MakeUnique<Literal>(shape);
- for (auto& pair : literal->pieces_) {
- Piece& piece = pair.second;
- if (ShapeUtil::IsArray(piece.subshape())) {
- memset(piece.untyped_data(), 0, piece.size_bytes());
- }
- }
+ literal->root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (ShapeUtil::IsArray(piece->subshape())) {
+ memset(piece->untyped_data(), 0, piece->size_bytes());
+ }
+ });
return literal;
}
-const SparseIndexArray* Literal::sparse_indices(
+const SparseIndexArray* LiteralBase::sparse_indices(
const ShapeIndex& shape_index) const {
return piece(shape_index).sparse_indices();
}
@@ -202,9 +234,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
+/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+ const LiteralSlice& bf16_literal) {
+ return ConvertType<bfloat16, float>(bf16_literal);
+}
+
+/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+ const LiteralSlice& f32_literal) {
+ return ConvertType<float, bfloat16>(f32_literal);
+}
+
template <typename NativeT>
Status Literal::CopySliceFromInternal(
- const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
@@ -217,8 +259,8 @@ Status Literal::CopySliceFromInternal(
if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
ShapeUtil::Rank(shape()) == 0) {
- // If any of the two shapes are scalars, we can just call the StridedCopy()
- // directly, and we know we will be copying only one value.
+ // If any of the two shapes are scalars, we can just call the
+ // StridedCopy() directly, and we know we will be copying only one value.
TF_RET_CHECK(copy_size.empty());
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
@@ -264,7 +306,7 @@ Status Literal::CopySliceFromInternal(
return Status::OK();
}
-Status Literal::CopyElementFrom(const Literal& src_literal,
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
@@ -293,22 +335,21 @@ std::vector<Literal> Literal::DecomposeTuple() {
elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
/*allocate_arrays=*/false));
Literal& element = elements.back();
- for (auto& pair : element.pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& dest_piece = pair.second;
- ShapeIndex src_index = {i};
- for (int64 j : index) {
- src_index.push_back(j);
- }
- Piece& src_piece = piece(src_index);
-
- // Move the respective buffer and sparse indices over to the element
- // Literal.
- dest_piece.set_buffer(src_piece.buffer());
- src_piece.set_buffer(nullptr);
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- src_piece.set_sparse_indices(nullptr);
- }
+ element.root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* dest_piece) {
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer and sparse indices over to the element
+ // Literal.
+ dest_piece->set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ dest_piece->set_sparse_indices(src_piece.sparse_indices());
+ src_piece.set_sparse_indices(nullptr);
+ });
}
// Set this literal to be nil-shaped.
*this = Literal();
@@ -331,9 +372,9 @@ std::vector<Literal> Literal::DecomposeTuple() {
}
namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data
+// in the array slices are indicated by dest_shape and src_shape
+// respectively.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
@@ -351,7 +392,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
-Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
if (ShapeUtil::Equal(subshape(), src.subshape())) {
// If the layouts are equal it's faster just to memcpy.
memcpy(buffer(), src.buffer(), src.size_bytes());
@@ -381,14 +422,15 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
#undef COPY_ELEMENTS
default:
return Unimplemented(
- "Copying a Literal object with element type %s is not implemented.",
+ "Copying a Literal object with element type %s is not "
+ "implemented.",
PrimitiveType_Name(subshape().element_type()).c_str());
}
}
return Status::OK();
}
-Status Literal::CopyFrom(const Literal& src_literal,
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
@@ -402,36 +444,33 @@ Status Literal::CopyFrom(const Literal& src_literal,
ShapeUtil::HumanString(src_subshape).c_str());
}
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Determine if this index is in the part of this literal that we want to
- // copy over from src_literal.
- bool in_subtree_to_copy = true;
- for (int i = 0; i < dest_shape_index.size(); ++i) {
- if (index[i] != dest_shape_index[i]) {
- in_subtree_to_copy = false;
- break;
- }
- }
- if (!in_subtree_to_copy) {
- continue;
- }
-
- // Construct the index of the corresponding piece in the source literal.
- ShapeIndex src_piece_index = src_shape_index;
- for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
- src_piece_index.push_back(index[i]);
- }
+ return root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (!ShapeUtil::IsArray(piece->subshape())) {
+ return Status::OK();
+ }
- TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
- }
- return Status::OK();
-}
+ // Determine if this index is in the part of this literal that we want
+ // to copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ return Status::OK();
+ }
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+ TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+ return Status::OK();
+ });
+}
Status Literal::MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index) {
@@ -444,37 +483,32 @@ Status Literal::MoveFrom(Literal&& src_literal,
ShapeUtil::HumanString(src_literal.shape()).c_str());
}
- if (!(owns_buffers_ && src_literal.owns_buffers_)) {
- return InvalidArgument(
- "Source and destination literals must both own their buffers (ie, not "
- "be views)");
- }
+ src_literal.root_piece_->ForEachSubpiece(
+ [&](const ShapeIndex& src_index, const Piece& src_piece) {
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ return;
+ }
- for (auto& pair : src_literal.pieces_) {
- const ShapeIndex& src_index = pair.first;
- Piece& src_piece = pair.second;
- if (!ShapeUtil::IsArray(src_piece.subshape())) {
- continue;
- }
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ delete dest_piece.sparse_indices();
+ dest_piece.set_sparse_indices(src_piece.sparse_indices());
+ });
- ShapeIndex dest_index = dest_shape_index;
- for (int64 i : src_index) {
- dest_index.push_back(i);
- }
- Piece& dest_piece = piece(dest_index);
- delete[] dest_piece.buffer();
- dest_piece.set_buffer(src_piece.buffer());
- delete dest_piece.sparse_indices();
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- }
+ src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ delete src_literal.root_piece_;
+ src_literal.root_piece_ = new LiteralBase::Piece();
+ src_literal.root_piece_->set_subshape(src_literal.shape_.get());
- src_literal.shape_ = ShapeUtil::MakeNil();
- src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
- src_literal.piece({}).set_subshape(&src_literal.shape_);
return Status::OK();
}
-Status Literal::CopySliceFrom(const Literal& src_literal,
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
@@ -743,7 +777,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
return CreateR2FromArray2D(*value);
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Layout& new_layout, const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
@@ -755,7 +789,7 @@ std::unique_ptr<Literal> Literal::Relayout(
return result;
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
@@ -774,7 +808,7 @@ std::unique_ptr<Literal> Literal::Relayout(
return result;
}
-StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
@@ -788,7 +822,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
+ *output->mutable_shape_do_not_use() =
+ ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
@@ -802,7 +837,79 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
return std::move(output);
}
-std::unique_ptr<Literal> Literal::Transpose(
+/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const LiteralSlice& literal) {
+ int64 new_num_elements = 1;
+ for (int64 i = 0; i < new_dimensions.size(); ++i) {
+ new_num_elements *= new_dimensions[i];
+ }
+ CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+ CHECK_EQ(new_dimensions.size(), minor_to_major.size());
+
+ auto new_literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
+
+ // Create a new shape with the given minor-to-major layout. This shape is used
+ // solely for converting linear address to multi-dimensional addresses when
+ // writing elements to the new literal.
+ Shape shape_with_layout = new_literal->shape();
+ *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+ // Copy data into new literal, element-by-element.
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ std::vector<int64> from_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ std::vector<int64> to_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+ switch (literal.shape().element_type()) {
+ case PRED:
+ new_literal->Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
+ break;
+ case U8:
+ new_literal->Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
+ break;
+ case U32:
+ new_literal->Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
+ break;
+ case S32:
+ new_literal->Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
+ break;
+ case U64:
+ new_literal->Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
+ break;
+ case S64:
+ new_literal->Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
+ break;
+ case F32:
+ new_literal->Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
+ break;
+ case F64:
+ new_literal->Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
+ break;
+ case C64:
+ new_literal->Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive element type: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+ }
+
+ return new_literal;
+}
+
+std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
@@ -819,8 +926,8 @@ std::unique_ptr<Literal> Literal::Transpose(
// representation intact.
// For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
// The shape with affine layout resulting from that operation will be
- // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
- // most minor.
+ // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized),
+ // the most minor.
//
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
@@ -836,12 +943,11 @@ std::unique_ptr<Literal> Literal::Transpose(
std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
- root_piece().size_bytes());
+ std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
-std::unique_ptr<Literal> Literal::Slice(
+std::unique_ptr<Literal> LiteralBase::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
@@ -909,20 +1015,20 @@ std::unique_ptr<Literal> Literal::Slice(
}
}
-Literal Literal::Clone() const {
+Literal LiteralBase::Clone() const {
Literal result(shape());
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> Literal::CloneToUnique() const {
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
auto result = MakeUnique<Literal>(shape());
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
-string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
switch (subshape.element_type()) {
@@ -962,8 +1068,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
}
}
-string Literal::GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetSparseElementAsString(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsSparseArray(subshape));
switch (subshape.element_type()) {
@@ -1017,7 +1123,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number,
}
}
-StatusOr<int64> Literal::GetIntegralAsS64(
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
@@ -1070,7 +1176,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
return Status::OK();
}
-tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
@@ -1082,10 +1188,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements();
}
-Literal Literal::GetFirstScalarLiteral() const {
- CHECK(ShapeUtil::IsArray(shape_));
- CHECK_GT(ShapeUtil::ElementsIn(shape_), 0);
- switch (shape_.element_type()) {
+Literal LiteralBase::GetFirstScalarLiteral() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
+ switch (shape().element_type()) {
case PRED:
return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
// 8 bit types.
@@ -1121,11 +1227,11 @@ Literal Literal::GetFirstScalarLiteral() const {
case U64:
return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
default:
- LOG(FATAL) << "Unhandled primitive type " << shape_.element_type();
+ LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
}
}
-void Literal::Piece::SortSparseElements() {
+void LiteralBase::Piece::SortSparseElements() {
switch (subshape().element_type()) {
case PRED:
SortSparseElementsInternal<bool>();
@@ -1176,7 +1282,7 @@ void Literal::Piece::SortSparseElements() {
}
template <typename NativeT>
-void Literal::Piece::SortSparseElementsInternal() {
+void LiteralBase::Piece::SortSparseElementsInternal() {
CHECK(LayoutUtil::IsSparseArray(subshape()));
int64 num_elements = sparse_indices()->index_count();
auto values = data<NativeT>();
@@ -1186,10 +1292,11 @@ void Literal::Piece::SortSparseElementsInternal() {
}
namespace {
-
-void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+ CHECK(LayoutUtil::HasLayout(literal.shape()));
+ CHECK(LayoutUtil::HasLayout(subshape));
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
@@ -1348,13 +1455,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
} // namespace
-int64 Literal::sparse_element_count() const {
+int64 LiteralBase::sparse_element_count() const {
CHECK(LayoutUtil::IsSparseArray(shape()));
return sparse_indices()->index_count();
}
-string Literal::ToString(bool print_layout) const {
+string LiteralBase::ToString(bool print_layout) const {
std::vector<string> pieces;
+ CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
return tensorflow::str_util::Join(pieces, "");
}
@@ -1362,7 +1470,7 @@ string Literal::ToString(bool print_layout) const {
/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
std::vector<Shape> element_shapes;
- for (const Literal* element : elements) {
+ for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
@@ -1372,6 +1480,19 @@ string Literal::ToString(bool print_layout) const {
return literal;
}
+/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+ tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+ std::vector<Shape> element_shapes;
+ for (const auto& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
+
/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
std::vector<Shape> element_shapes;
@@ -1387,7 +1508,7 @@ string Literal::ToString(bool print_layout) const {
return literal;
}
-void Literal::EachCellAsString(
+void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::HasZeroElements(shape())) {
@@ -1403,7 +1524,7 @@ void Literal::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const Literal& src_literal, const ConverterType& converter) {
+ const LiteralBase& src_literal, const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
src_literal.shape(),
@@ -1419,7 +1540,8 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+ const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1428,7 +1550,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
};
@@ -1436,19 +1558,19 @@ BitcastBetweenNativeTypes(const Literal& src_literal) {
src_literal, converter);
}
-// This template specialization is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This specialization should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
+// This template specialization is here to make the compiler happy. bit_cast
+// has a static check that the types are the same size. This specialization
+// should never be used because the source and destination types are checked
+// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
@@ -1466,7 +1588,7 @@ std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
@@ -1486,7 +1608,7 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type,
+ const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
@@ -1521,7 +1643,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
}
StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) {
+ const LiteralBase& literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
return literal.CloneToUnique();
@@ -1555,17 +1678,18 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> Literal::Convert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
return InvalidArgument(
- "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d "
+ "!= "
"%d",
PrimitiveType_Name(shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str(),
@@ -1575,7 +1699,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
@@ -1590,7 +1714,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
}
std::vector<Literal> elements;
for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- auto element = LiteralView::Create(*this, {i});
+ auto element = LiteralSlice(*this, {i});
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
@@ -1602,8 +1726,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
}
template <typename NativeT>
-bool Literal::Piece::EqualElementsInternal(
- const Literal::Piece& other, std::vector<int64>* multi_index) const {
+bool LiteralBase::Piece::EqualElementsInternal(
+ const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == ShapeUtil::Rank(subshape())) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
@@ -1617,7 +1741,7 @@ bool Literal::Piece::EqualElementsInternal(
return true;
}
-bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
std::vector<int64> multi_index;
@@ -1645,32 +1769,31 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
case C64:
return EqualElementsInternal<complex64>(other, &multi_index);
default:
- LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+ LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
<< PrimitiveType_Name(subshape().element_type());
}
}
-bool Literal::operator==(const Literal& other) const {
+bool LiteralBase::operator==(const LiteralBase& other) const {
if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
- const Piece& other_piece = other.piece(index);
- if (!piece.EqualElements(other_piece)) {
- return false;
- }
- }
- return true;
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
+ }
+ return true;
+ });
}
namespace {
-
template <typename NativeT>
static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
NativeT value) {
@@ -1684,11 +1807,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
} // namespace
-bool Literal::IsAll(int8 value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
+bool LiteralBase::IsAll(int8 value) const {
+ return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+ const Piece& piece) {
if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
+ return true;
}
auto piece_is_all = [&]() {
@@ -1741,41 +1864,41 @@ bool Literal::IsAll(int8 value) const {
if (!piece_is_all()) {
return false;
}
- }
- return true;
-}
+ return true;
+ });
+}
-bool Literal::IsAllFloat(float value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
+bool LiteralBase::IsAllFloat(float value) const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
- static_cast<bfloat16>(value));
- default:
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
return false;
- }
- };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ }
+ return true;
+ });
}
-bool Literal::IsAllComplex(complex64 value) const {
+bool LiteralBase::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
@@ -1785,93 +1908,93 @@ bool Literal::IsAllComplex(complex64 value) const {
}
}
-bool Literal::IsAllFirst() const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Empty shapes are not all the first element since there is no first
- // element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
- return false;
- }
- auto piece_is_all = [&]() {
- switch (piece.subshape().element_type()) {
- case PRED: {
- auto data = piece.data<bool>();
- return AllElementsEqualValue<bool>(data, data[0]);
- }
- // 8 bit types
- case S8: {
- auto data = piece.data<int8>();
- return AllElementsEqualValue<int8>(data, data[0]);
- }
- case U8: {
- auto data = piece.data<uint8>();
- return AllElementsEqualValue<uint8>(data, data[0]);
- }
- // 16 bit types
- case BF16: {
- auto data = piece.data<bfloat16>();
- return AllElementsEqualValue<bfloat16>(data, data[0]);
- }
- case F16: {
- auto data = piece.data<half>();
- return AllElementsEqualValue<half>(data, data[0]);
+bool LiteralBase::IsAllFirst() const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
}
- case S16: {
- auto data = piece.data<int16>();
- return AllElementsEqualValue<int16>(data, data[0]);
- }
- case U16: {
- auto data = piece.data<uint16>();
- return AllElementsEqualValue<uint16>(data, data[0]);
- }
- // 32 bit types
- case F32: {
- auto data = piece.data<float>();
- return AllElementsEqualValue<float>(data, data[0]);
- }
- case U32: {
- auto data = piece.data<uint32>();
- return AllElementsEqualValue<uint32>(data, data[0]);
- }
- case S32: {
- auto data = piece.data<int32>();
- return AllElementsEqualValue<int32>(data, data[0]);
- }
- // 64 bit types
- case C64: {
- auto data = piece.data<complex64>();
- return AllElementsEqualValue<complex64>(data, data[0]);
- }
- case F64: {
- auto data = piece.data<double>();
- return AllElementsEqualValue<double>(data, data[0]);
- }
- case S64: {
- auto data = piece.data<int64>();
- return AllElementsEqualValue<int64>(data, data[0]);
- }
- case U64: {
- auto data = piece.data<uint64>();
- return AllElementsEqualValue<uint64>(data, data[0]);
- }
- default:
+
+ // Empty shapes are not all the first element since there is no first
+ // element.
+ if (ShapeUtil::HasZeroElements(piece.subshape())) {
return false;
- }
- };
+ }
+ auto piece_is_all = [&]() {
+ switch (piece.subshape().element_type()) {
+ case PRED: {
+ auto data = piece.data<bool>();
+ return AllElementsEqualValue<bool>(data, data[0]);
+ }
+ // 8 bit types
+ case S8: {
+ auto data = piece.data<int8>();
+ return AllElementsEqualValue<int8>(data, data[0]);
+ }
+ case U8: {
+ auto data = piece.data<uint8>();
+ return AllElementsEqualValue<uint8>(data, data[0]);
+ }
+ // 16 bit types
+ case BF16: {
+ auto data = piece.data<bfloat16>();
+ return AllElementsEqualValue<bfloat16>(data, data[0]);
+ }
+ case F16: {
+ auto data = piece.data<half>();
+ return AllElementsEqualValue<half>(data, data[0]);
+ }
+ case S16: {
+ auto data = piece.data<int16>();
+ return AllElementsEqualValue<int16>(data, data[0]);
+ }
+ case U16: {
+ auto data = piece.data<uint16>();
+ return AllElementsEqualValue<uint16>(data, data[0]);
+ }
+ // 32 bit types
+ case F32: {
+ auto data = piece.data<float>();
+ return AllElementsEqualValue<float>(data, data[0]);
+ }
+ case U32: {
+ auto data = piece.data<uint32>();
+ return AllElementsEqualValue<uint32>(data, data[0]);
+ }
+ case S32: {
+ auto data = piece.data<int32>();
+ return AllElementsEqualValue<int32>(data, data[0]);
+ }
+ // 64 bit types
+ case C64: {
+ auto data = piece.data<complex64>();
+ return AllElementsEqualValue<complex64>(data, data[0]);
+ }
+ case F64: {
+ auto data = piece.data<double>();
+ return AllElementsEqualValue<double>(data, data[0]);
+ }
+ case S64: {
+ auto data = piece.data<int64>();
+ return AllElementsEqualValue<int64>(data, data[0]);
+ }
+ case U64: {
+ auto data = piece.data<uint64>();
+ return AllElementsEqualValue<uint64>(data, data[0]);
+ }
+ default:
+ return false;
+ }
+ };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
}
-bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
@@ -1904,7 +2027,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
const tensorflow::gtl::ArraySlice<NativeT> src) {
@@ -1913,7 +2035,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest,
} // namespace
-void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
*proto->mutable_shape() = subshape();
switch (subshape().element_type()) {
case PRED:
@@ -1969,18 +2091,17 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const {
}
}
-const void* Literal::Piece::untyped_data() const {
+const void* LiteralBase::Piece::untyped_data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
-void* Literal::Piece::untyped_data() {
+void* LiteralBase::Piece::untyped_data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
const RepeatedFieldT& src) {
@@ -1995,7 +2116,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
-Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in Literal::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
@@ -2062,21 +2183,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
return Status::OK();
}
-LiteralProto Literal::ToProto() const {
+LiteralProto LiteralBase::ToProto() const {
LiteralProto proto;
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
-
- LiteralProto* proto_piece = &proto;
- for (int64 i : index) {
- while (proto_piece->tuple_literals_size() <= i) {
- proto_piece->add_tuple_literals();
- }
- proto_piece = proto_piece->mutable_tuple_literals(i);
- }
- piece.WriteToProto(proto_piece);
- }
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
if (LayoutUtil::IsSparseArray(shape())) {
CopyToRepeatedField(proto.mutable_sparse_indices(),
@@ -2098,33 +2217,39 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
auto literal = MakeUnique<Literal>(proto.shape());
- for (auto& pair : literal->pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- const LiteralProto* proto_element = &proto;
- for (int64 i : index) {
- TF_RET_CHECK(i < proto_element->tuple_literals_size());
- proto_element = &proto_element->tuple_literals(i);
- }
+ TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
- if (ShapeUtil::IsTuple(piece.subshape())) {
- if (proto_element->tuple_literals_size() !=
- ShapeUtil::TupleElementCount(piece.subshape())) {
- return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
- ShapeUtil::TupleElementCount(piece.subshape()),
- proto_element->tuple_literals_size());
- }
- continue;
- }
+ if (ShapeUtil::IsTuple(piece->subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece->subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece->subshape()),
+ proto_element->tuple_literals_size());
+ }
+ return Status::OK();
+ }
- TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
- TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
- }
+ CHECK(ShapeUtil::IsArray(piece->subshape()));
+ TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+ return Status::OK();
+ }));
return std::move(literal);
}
-const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+/* static */ string Literal::MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
@@ -2132,11 +2257,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data();
}
-int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
return piece(shape_index).size_bytes();
}
-string Literal::GetR1U8AsString() const {
+string LiteralBase::GetR1U8AsString() const {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().element_type(), U8);
@@ -2144,12 +2269,14 @@ string Literal::GetR1U8AsString() const {
ShapeUtil::ElementsIn(shape()));
}
-/* static */ const LiteralView LiteralView::Create(
- const Literal& literal, const ShapeIndex& view_root) {
- return LiteralView(literal, view_root);
-}
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
-size_t Literal::Hash() const {
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
+size_t LiteralBase::Hash() const {
using tensorflow::Hash64;
using tensorflow::Hash64Combine;
@@ -2170,46 +2297,4 @@ size_t Literal::Hash() const {
return hash_value;
}
-LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
- shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
- pieces_ = ShapeTree<Piece>(shape_);
- owns_buffers_ = false;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- ShapeIndex src_index = view_root;
- for (int64 i : index) {
- src_index.push_back(i);
- }
- const Piece& src_piece = literal.piece(src_index);
- piece.set_buffer(src_piece.buffer());
- piece.set_sparse_indices(src_piece.sparse_indices());
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
-}
-
-LiteralView::~LiteralView() {}
-
-LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
-
-LiteralView& LiteralView::operator=(const LiteralView& other) {
- CopyFrom(other);
- return *this;
-}
-
-void LiteralView::CopyFrom(const LiteralView& other) {
- // We can't use the default copy-constructor/copy-assignment because
- // Piece::subshape_ points to subshapes within the Shape of the owning
- // Literal/LiteralView.
- shape_ = other.shape();
- pieces_ = other.pieces_;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = false;
-}
-
} // namespace xla