diff options
19 files changed, 404 insertions, 214 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 8cf198239c..0100bf51ed 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice( // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } void XlaTransferManager::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, - [=, &shaped_buffer]( - xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) { + device_to_host_stream_, shaped_buffer, literal, + [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { - TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() + VLOG(1) << "Transfer from device as literal: " << literal.ToString() << " " << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. 
- Status status; - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - status = errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } return status; }()); }); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2fb66913ad..77da1bf29c 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + Status HostTensorsToBorrowingLiteralTuple( tensorflow::gtl::ArraySlice<Tensor> host_tensors, xla::BorrowingLiteral* literal) { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 0610a57029..09d6fa8116 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,16 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. 
+Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 8a6c5fb9a7..4d96316d3b 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -303,7 +303,7 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal( const Shape& shape, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique<Literal>(); + auto literal = MakeUnique<Literal>(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, shape, literal.get())); return std::move(literal); diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 0545deb096..36e472568e 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { return out; } -Literal::StrideConfig::StrideConfig( +MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, tensorflow::gtl::ArraySlice<int64> dimensions) : dimensions(dimensions), @@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), 
shape_(MakeUnique<Shape>(shape)) { + : MutableLiteralBase() { + shape_ = MakeUnique<Shape>(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() { }); } -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} Literal& Literal::operator=(Literal&& other) { DCHECK(&other.root_piece_->subshape() == other.shape_.get()); @@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices( return piece(shape_index).sparse_indices(); } -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { return piece(shape_index).sparse_indices(); } template <typename NativeT> -Status Literal::CopySliceFromInternal( +Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base, tensorflow::gtl::ArraySlice<int64> dest_base, tensorflow::gtl::ArraySlice<int64> copy_size) { @@ -225,8 +229,8 @@ Status Literal::CopySliceFromInternal( // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) { // Map from multi-dimensional index, to source index. 
@@ -253,9 +257,10 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice<int64> src_index, - tensorflow::gtl::ArraySlice<int64> dest_index) { +Status MutableLiteralBase::CopyElementFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice<int64> src_index, + tensorflow::gtl::ArraySlice<int64> dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto( - const LiteralProto& proto) { +/* static */ StatusOr<std::unique_ptr<Literal>> +MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -405,9 +410,9 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { const Shape& dest_subshape = ShapeUtil::GetSubshape(shape(), dest_shape_index); const Shape& src_subshape = @@ -482,10 +487,11 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice<int64> src_base, - tensorflow::gtl::ArraySlice<int64> dest_base, - tensorflow::gtl::ArraySlice<int64> copy_size) { +Status MutableLiteralBase::CopySliceFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice<int64> src_base, + tensorflow::gtl::ArraySlice<int64> dest_base, + 
tensorflow::gtl::ArraySlice<int64> copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, shape().element_type()); } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(element_count(), values.bits()); @@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index, - int64 value) { +Status MutableLiteralBase::SetIntegralAsS64( + tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex( return p.sparse_indices()->At(sparse_element_number); } -void Literal::SortSparseElements(const ShapeIndex& shape_index) { +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } @@ -1391,11 +1397,11 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape( elements.push_back(std::move(*new_element)); } auto converted = MakeUnique<Literal>(); - *converted = Literal::MoveIntoTuple(&elements); + *converted = MutableLiteralBase::MoveIntoTuple(&elements); return std::move(converted); } -/* static */ Literal Literal::MoveIntoTuple( +/* static */ Literal MutableLiteralBase::MoveIntoTuple( tensorflow::gtl::MutableArraySlice<Literal> elements) { std::vector<Shape> element_shapes; for (const Literal& element : elements) { @@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest, } // namespace Status 
LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); @@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } -void* Literal::untyped_data(const ShapeIndex& shape_index) { +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } @@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. 
+ CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = MakeUnique<Shape>(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = MakeUnique<Shape>(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = MakeUnique<Shape>(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = MakeUnique<Shape>(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + 
: MutableLiteralBase() { + shape_ = MakeUnique<Shape>(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast<char*>(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { CHECK(ShapeUtil::IsTuple(shape)); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { @@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } } -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(MakeUnique<Shape>(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index dd67dfa8d4..92c0f903cb 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -310,9 +310,10 @@ class LiteralBase { // type of literal itself (0 for numeric types, and false for predicates). 
// // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique<Literal>(shape), - // followed by the call to Literal::Populate can be used instead. + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // MakeUnique<Literal>(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. static std::unique_ptr<Literal> CreateFromShape(const Shape& shape); protected: @@ -534,7 +535,7 @@ class LiteralBase { virtual const Piece& root_piece() const = 0; // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; + friend class MutableLiteralBase; friend class LiteralSlice; friend class BorrowingLiteral; @@ -545,33 +546,10 @@ class LiteralBase { tensorflow::gtl::ArraySlice<int64> start_indices) const; }; -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. 
If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } + virtual ~MutableLiteralBase() = 0; // Returns a MutableArraySlice view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the @@ -587,6 +565,10 @@ class Literal : public LiteralBase { // is not a sparse array. SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -613,21 +595,6 @@ class Literal : public LiteralBase { const ShapeIndex& dest_shape_index = {}, const ShapeIndex& src_shape_index = {}); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector<Literal> DecomposeTuple(); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. 
Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -730,12 +697,7 @@ class Literal : public LiteralBase { static StatusOr<std::unique_ptr<Literal>> CreateFromProto( const LiteralProto& proto); - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - + protected: // Returns the piece at the given ShapeIndex. Piece& piece(const ShapeIndex& shape_index) { return const_cast<Piece&>(LiteralBase::piece(shape_index)); @@ -783,12 +745,83 @@ class Literal : public LiteralBase { template <typename NativeT, typename FnType> Status PopulateInternal(const FnType& generator, bool parallel); + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. 
If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector<Literal> DecomposeTuple(); + + private: // Deallocate the buffers held by this literal. void DeallocateBuffers(); - friend class LiteralBase; + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. 
+ MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); }; -std::ostream& operator<<(std::ostream& out, const Literal& literal); // A read-only view of a Literal. A LiteralSlice contains pointers to shape and // literal buffers always owned by others. @@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. 
std::unique_ptr<Shape> shape_; }; @@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data( } template <typename NativeT> -tensorflow::gtl::MutableArraySlice<NativeT> Literal::data( +tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data( const ShapeIndex& shape_index) { return piece(shape_index).data<NativeT>(); } @@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get( } template <typename NativeT> -inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice<int64> multi_index, + const ShapeIndex& shape_index, NativeT value) { return piece(shape_index).Set<NativeT>(multi_index, value); } template <typename NativeT> -inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index, - NativeT value) { +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) { return root_piece().Set<NativeT>(multi_index, value); } @@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, } template <typename NativeT> -void Literal::AppendSparseElement( +void MutableLiteralBase::AppendSparseElement( tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); @@ -959,7 +993,8 @@ void LiteralBase::EachCell( } template <typename NativeT> -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice<NativeT> values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) { } template <typename NativeT> -void Literal::PopulateR2( +void MutableLiteralBase::PopulateR2( std::initializer_list<std::initializer_list<NativeT>> 
values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 2); @@ -996,7 +1031,7 @@ void Literal::PopulateR2( } template <typename NativeT> -void Literal::PopulateFromArray(const Array<NativeT>& values) { +void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType<NativeT>()); @@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) { } template <typename NativeT> -void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) { +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) { +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) { PopulateFromArray(values); } template <typename NativeT> -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice<NativeT> values, - bool sort) { +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices, } template <typename NativeT, typename FnType> -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1092,17 +1128,17 @@ Status 
Literal::PopulateInternal(const FnType& generator, bool parallel) { return Status::OK(); } template <typename NativeT, typename FnType> -Status Literal::Populate(const FnType& generator) { +Status MutableLiteralBase::Populate(const FnType& generator) { return PopulateInternal<NativeT>(generator, /*parallel=*/false); } template <typename NativeT, typename FnType> -Status Literal::PopulateParallel(const FnType& generator) { +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { return PopulateInternal<NativeT>(generator, /*parallel=*/true); } template <typename NativeT> -void Literal::PopulateWithValue(NativeT value) { +void MutableLiteralBase::PopulateWithValue(NativeT value) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType<NativeT>()); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 356f12ed78..5d33df7d40 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" using tensorflow::strings::StrCat; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 156166bf2b..59bc7e0e16 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from @@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice<int64> dimensions( tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*LiteralUtil::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - 
*literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector<std::unique_ptr<Literal>> elements; std::vector<std::pair<void*, int64>> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice<int64> dimensions( - tensorflow::bit_cast<const int64*>( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = LiteralUtil::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 593575c0fd..80ef953d53 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ 
b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include <vector> +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index e314a469f0..0ce2db907b 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) { + MutableBorrowingLiteral literal, std::function<void(Status)> done) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); } - done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); } -StatusOr<std::unique_ptr<Literal>> -GenericTransferManager::TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr<Literal> literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( 
/*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } Status GenericTransferManager::TransferLiteralToDeviceAsync( @@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed( Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { return Unimplemented("Generic transfer from Outfeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3cd002c1bf..6c1a21587a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function<void(Status)> done) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, @@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - 
Literal* literal) override; + MutableBorrowingLiteral literal) override; Status ResetDevices( tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override; @@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager { const Shape& shape, se::DeviceMemoryBase* region) override; private: - StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 79b3f1efec..a2f53f8446 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -117,38 +117,37 @@ StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal( return std::move(buffer); } -static std::unique_ptr<Literal> ShapeTreeToLiteral( +static void ShapeTreeToLiteral( ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree) { // This is a struct instead of a lambda for std::function-free recursion. 
struct Helper { - static std::unique_ptr<Literal> helper( + static void helper( ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); if (ShapeUtil::IsArray(shape)) { - return (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; } CHECK(ShapeUtil::IsTuple(shape)) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); - std::vector<std::unique_ptr<Literal>> tuple_operands; for (int64 i = 0; i < tuple_element_count; ++i) { index->back() = i; - tuple_operands.push_back(helper(shape_tree, index)); + helper(shape_tree, index); } index->pop_back(); - return LiteralUtil::MakeTupleOwned(std::move(tuple_operands)); } }; ShapeIndex index; - return Helper::helper(shape_tree, &index); + Helper::helper(shape_tree, &index); } Status GpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* /*executor*/, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers( &literal_shape); @@ -162,6 +161,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( return; } *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + MakeUnique<MutableBorrowingLiteral>(literal, index)); }); // Give the tree of buffers to the outfeed mananger. The device will fill it @@ -169,8 +170,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); outfeed_manager->EnqueueDestination(&outfeed_buffers); - // Now turn the tree of buffers back into a literal. - *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers)); + // Now wait until the tree of buffers is written.
+ ShapeTreeToLiteral(&outfeed_buffers); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index dceeb9e2eb..7929042869 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -42,7 +42,7 @@ class GpuTransferManager : public GenericTransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h index a752eb7011..160ba4b691 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -36,22 +36,19 @@ class OutfeedBuffer { OutfeedBuffer(int64 length) : length_(length) {} // Waits for the device transfer to be finished. - std::unique_ptr<Literal> WaitUntilAvailable() { - done_.WaitForNotification(); - return std::move(destination_); - } + void WaitUntilAvailable() { done_.WaitForNotification(); } int64 length() const { return length_; } - void set_destination(std::unique_ptr<Literal> destination) { + void set_destination(std::unique_ptr<MutableBorrowingLiteral> destination) { destination_ = std::move(destination); } - Literal* destination() { return destination_.get(); } + MutableBorrowingLiteral* destination() { return destination_.get(); } // Callback to signal that this buffer is consumed. 
void Done() { done_.Notify(); } private: - std::unique_ptr<Literal> destination_; + std::unique_ptr<MutableBorrowingLiteral> destination_; const int64 length_; tensorflow::Notification done_; }; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 7986e63f43..b99d998c4d 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -50,10 +50,6 @@ Status OutfeedThunk::ExecuteOnStream( if (!*buffer) { // Tuple pointers. return Status::OK(); } - // Allocate storage for the literal data. - const Shape& shape = - ShapeUtil::GetSubshape(outfeed_buffers->shape(), index); - (*buffer)->set_destination(Literal::CreateFromShape(shape)); BufferAllocation::Slice slice = outfeed_slices_.element(index); se::DeviceMemoryBase data_address; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 212db0643c..e970e885c5 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1052,10 +1052,10 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + Literal literal(arg->shape_with_layout()); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); + executor, arg->shape_with_layout(), literal)); *result->mutable_literal() = literal.ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 7232c658b3..32d368a904 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -43,15 +43,39 @@ TransferManager::GetPlatformTransferManagers() { StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice( se::Stream* 
stream, const ShapedBuffer& device_buffer) { StatusOr<std::unique_ptr<Literal>> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, - [&](StatusOr<std::unique_ptr<Literal>> arg) { - ret = std::move(arg); + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return MakeUnique<Literal>(std::move(literal)); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + Status ret; + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; n.Notify(); }); n.WaitForNotification(); @@ -76,22 +100,27 @@ Status TransferManager::TransferLiteralToDevice( StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { + StatusOr<std::unique_ptr<Literal>> ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
- StatusOr<std::unique_ptr<Literal>> ret; se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferArrayFromDevice(substream, shape, source, - [&](StatusOr<std::unique_ptr<Literal>> arg) { - ret = std::move(arg); + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; n.Notify(); }); n.WaitForNotification(); - return ret; + if (!s.ok()) { + return s; + } + return MakeUnique<Literal>(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -130,7 +159,7 @@ Status TransferManager::TransferArrayToDeviceAsync( void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) { + const MutableBorrowingLiteral& literal, std::function<void(Status)> done) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", @@ -147,7 +176,8 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 82c599e482..475a2e5c14 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -59,6 +59,9 @@ class TransferManager { // This function should be avoided in favor of the asynchronous version below. 
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function<void(Status)> done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,10 +105,10 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function<void(Status)> done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, @@ -120,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. 
- virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override; diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2fd70b72b5..d9c1dfa3f7 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -586,9 +586,9 @@ XLA_TEST_F(TupleHloTest, })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3})); - auto literal = MakeUnique<Literal>(); + auto literal = MakeUnique<Literal>(expected->shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), literal.get())); + backend().default_stream_executor(), expected->shape(), *literal)); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); }