aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal.cc')
-rw-r--r--tensorflow/compiler/xla/literal.cc149
1 files changed, 68 insertions, 81 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3f7635bd40..5035f41988 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = absl::make_unique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+ Literal literal(shape);
+ literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
@@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+ const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = absl::make_unique<Literal>(proto.shape());
+ Literal literal(proto.shape());
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = absl::make_unique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
+ Literal result(new_shape);
+ TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
- result->shape(),
+ result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
+ TF_CHECK_OK(result.CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+StatusOr<Literal> LiteralBase::Broadcast(
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
@@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
+ Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
- char* dest_data = static_cast<char*>(result->untyped_data());
+ char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+StatusOr<Literal> LiteralBase::Reshape(
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
- std::unique_ptr<Literal> output;
+ Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
- output = CloneToUnique();
+ output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape_do_not_use() =
+ *output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
ShapeUtil::HumanString(shape()),
- ShapeUtil::HumanString(output->shape()));
+ ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- absl::Span<const int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = absl::make_unique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ Literal new_literal(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
+Literal LiteralBase::SliceInternal(
const Shape& result_shape, absl::Span<const int64> start_indices) const {
- auto result_literal = absl::make_unique<Literal>(result_shape);
+ Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
+ result_literal.EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
+ result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
-std::unique_ptr<Literal> LiteralBase::Slice(
- absl::Span<const int64> start_indices,
- absl::Span<const int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const {
return result;
}
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = absl::make_unique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+ const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
+ Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
+ auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = absl::make_unique<Literal>(
+ Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
- absl::Span<complex64> dest_data = result_literal->data<complex64>();
+ absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
PrimitiveType_Name(primitive_dest_type));
}
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
+ return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
@@ -1362,17 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
@@ -1381,11 +1361,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
+ elements.push_back(std::move(new_element));
}
- auto converted = absl::make_unique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
- return std::move(converted);
+ return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
@@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
+ case S8:
+ proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
+ element_count());
+ break;
case U8:
proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
element_count());
@@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
+ case S8: {
+ auto s8_data = data<int8>();
+ TF_RET_CHECK(proto.s8s().size() == s8_data.size());
+ std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
+ } break;
case U8: {
auto u8_data = data<uint8>();
TF_RET_CHECK(proto.u8s().size() == u8_data.size());