diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 149 |
1 files changed, 68 insertions, 81 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3f7635bd40..5035f41988 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = absl::make_unique<Literal>(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr<std::unique_ptr<Literal>> -MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { +/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto( + const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = absl::make_unique<Literal>(proto.shape()); + Literal literal(proto.shape()); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr<Literal> LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = absl::make_unique<Literal>(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr<Literal> LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast( +StatusOr<Literal> LiteralBase::Broadcast( const Shape& result_shape, absl::Span<const int64> dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); @@ -598,14 +597,14 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. std::vector<int64> scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast<char*>(result->untyped_data()); + char* dest_data = static_cast<char*>(result.untyped_data()); const char* source_data = static_cast<const char*>(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); @@ -627,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast( return std::move(result); } -StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape( +StatusOr<Literal> LiteralBase::Reshape( absl::Span<const int64> dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr<Literal> output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", ShapeUtil::HumanString(shape()), - ShapeUtil::HumanString(output->shape())); + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr<Literal> LiteralBase::Transpose( - absl::Span<const int64> permutation) const { +Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -687,32 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = absl::make_unique<Literal>(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template <typename NativeT> -std::unique_ptr<Literal> LiteralBase::SliceInternal( +Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span<const int64> start_indices) const { - auto result_literal = absl::make_unique<Literal>(result_shape); + Literal result_literal(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell<NativeT>( + result_literal.EachCell<NativeT>( [&](absl::Span<const int64> indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get<NativeT>(new_indices); - result_literal->Set<NativeT>(indices, value); + result_literal.Set<NativeT>(indices, value); }); return result_literal; } -std::unique_ptr<Literal> LiteralBase::Slice( - absl::Span<const int64> start_indices, - absl::Span<const int64> limit_indices) const { +Literal LiteralBase::Slice(absl::Span<const int64> start_indices, + absl::Span<const int64> limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr<Literal> LiteralBase::CloneToUnique() const { - auto result = absl::make_unique<Literal>(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - string LiteralBase::GetAsString(absl::Span<const int64> multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); @@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString( namespace { template <typename NativeSrcT, typename NativeDestT, typename ConverterType> -std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType<NativeDestT>())); auto src_data = src_literal.data<NativeSrcT>(); - auto dest_data = result_literal->template data<NativeDestT>(); + auto dest_data = result_literal.template data<NativeDestT>(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1208,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( } template <typename NativeSrcT, typename NativeDestT> -std::unique_ptr<Literal> ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); }; return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>( src_literal, converter); @@ -1217,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes( template <typename NativeSrcT, typename NativeDestT> typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr<Literal>>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast<NativeDestT>(src); @@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. template <typename NativeSrcT, typename NativeDestT> typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr<Literal>>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template <PrimitiveType primitive_src_type> -std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique<Literal>( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type; absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>(); - absl::Span<complex64> dest_data = result_literal->data<complex64>(); + absl::Span<complex64> dest_data = result_literal.data<complex64>(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast<float>(src_data[i]), 0); @@ -1254,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) { } template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type> -std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1273,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal, } template <PrimitiveType primitive_src_type> -StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1307,12 +1296,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type)); } -StatusOr<std::unique_ptr<Literal>> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr<Literal> ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1342,12 +1331,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch( } // namespace -StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert( +StatusOr<Literal> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert( +StatusOr<Literal> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1362,17 +1351,8 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const { if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this, - converter); - } return Convert(dest_shape.element_type()); } std::vector<Literal> elements; @@ -1381,11 +1361,9 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = absl::make_unique<Literal>(); - *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); - return std::move(converted); + return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( @@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data<bool>()); break; + case S8: + proto->set_s8s(static_cast<const signed char*>(data<int8>().data()), + element_count()); + break; case U8: proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()), element_count()); @@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds())); break; + case S8: { + auto s8_data = data<int8>(); + TF_RET_CHECK(proto.s8s().size() == s8_data.size()); + std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin()); + } break; case U8: { auto u8_data = data<uint8>(); TF_RET_CHECK(proto.u8s().size() == u8_data.size()); |