diff options
author | 2018-08-20 20:20:14 -0700 | |
---|---|---|
committer | 2018-08-20 20:23:24 -0700 | |
commit | e924d67bff8c4fb58c8316d00b662f8d1e80eb95 (patch) | |
tree | bf1b0f5b9d0c699150295f98187b19d6a10710a6 /tensorflow/compiler/xla/literal.cc | |
parent | 49115abfd39d30506679d9fdc572ccd2f7c22dbe (diff) |
[XLA] Use absl::make_unique instead of xla::MakeUnique.
Same for WrapUnique.
PiperOrigin-RevId: 209531124
Diffstat (limited to 'tensorflow/compiler/xla/literal.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 36e472568e..d54f051a1a 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,6 +22,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -134,7 +135,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { Literal::Literal(const Shape& shape, bool allocate_arrays) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(shape); + shape_ = absl::make_unique<Shape>(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -175,7 +176,7 @@ Literal& Literal::operator=(Literal&& other) { } std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique<Literal>(shape); + auto literal = absl::make_unique<Literal>(shape); literal->root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { @@ -289,7 +290,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = MakeUnique<Literal>(proto.shape()); + auto literal = absl::make_unique<Literal>(proto.shape()); TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -479,7 +480,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); - src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil()); + src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil()); delete src_literal.root_piece_; src_literal.root_piece_ = new LiteralBase::Piece(); src_literal.root_piece_->set_subshape(src_literal.shape_.get()); @@ -566,7 +567,7 @@ std::unique_ptr<Literal> LiteralBase::Relayout( Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = MakeUnique<Literal>(new_shape); + auto result = absl::make_unique<Literal>(new_shape); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -602,7 +603,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape); + std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in @@ -691,7 +692,7 @@ std::unique_ptr<Literal> LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = MakeUnique<Literal>(permuted_shape); + auto new_literal = absl::make_unique<Literal>(permuted_shape); DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); @@ -702,7 +703,7 @@ template <typename NativeT> std::unique_ptr<Literal> LiteralBase::SliceInternal( const Shape& result_shape, tensorflow::gtl::ArraySlice<int64> start_indices) const { - auto result_literal = MakeUnique<Literal>(result_shape); + auto result_literal = absl::make_unique<Literal>(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); result_literal->EachCell<NativeT>( [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) { @@ -756,7 +757,7 @@ Literal LiteralBase::Clone() const { } std::unique_ptr<Literal> LiteralBase::CloneToUnique() const { - auto result = MakeUnique<Literal>(shape()); + auto result = absl::make_unique<Literal>(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -1203,7 +1204,7 @@ template <typename NativeSrcT, typename NativeDestT, typename ConverterType> std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType( + auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType<NativeDestT>())); auto src_data = src_literal.data<NativeSrcT>(); @@ -1249,7 +1250,7 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { template <PrimitiveType primitive_src_type> std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique<Literal>( + auto result_literal = absl::make_unique<Literal>( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type; @@ -1396,7 +1397,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); elements.push_back(std::move(*new_element)); } - auto converted = MakeUnique<Literal>(); + auto converted = absl::make_unique<Literal>(); *converted = MutableLiteralBase::MoveIntoTuple(&elements); return std::move(converted); } @@ -1956,7 +1957,7 @@ MutableLiteralBase::~MutableLiteralBase() {} MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableBorrowingLiteral& literal) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(literal.shape()); + shape_ = absl::make_unique<Shape>(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1967,7 +1968,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( const MutableBorrowingLiteral& literal) { - shape_ = MakeUnique<Shape>(literal.shape()); + shape_ = absl::make_unique<Shape>(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1981,7 +1982,7 @@ MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableLiteralBase& literal) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(literal.shape()); + shape_ = absl::make_unique<Shape>(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1992,7 +1993,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(literal->shape()); + shape_ = absl::make_unique<Shape>(literal->shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2004,7 +2005,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral literal, const ShapeIndex& view_root) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape()); + shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2016,7 +2017,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : MutableLiteralBase() { - shape_ = MakeUnique<Shape>(shape); + shape_ = absl::make_unique<Shape>(shape); CHECK(LayoutUtil::HasLayout(*shape_)); CHECK(!ShapeUtil::IsTuple(*shape_)); @@ -2061,7 +2062,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique<Shape>(shape)) { + : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_)); @@ -2072,7 +2073,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique<Shape>(shape)) { + : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); |