aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc21
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc17
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h10
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc2
-rw-r--r--tensorflow/compiler/xla/literal.cc189
-rw-r--r--tensorflow/compiler/xla/literal.h186
-rw-r--r--tensorflow/compiler/xla/literal_util.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc32
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h3
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc22
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h15
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc21
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_manager.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/service.cc4
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc50
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h24
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc4
19 files changed, 404 insertions, 214 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 8cf198239c..0100bf51ed 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -101,34 +101,27 @@ Status XlaTransferManager::TransferLiteralToDevice(
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
+
return Status::OK();
}
void XlaTransferManager::TransferLiteralFromDevice(
Tensor* host_tensor, const Tensor& device_tensor,
const StatusCallback& done) const {
+ xla::MutableBorrowingLiteral literal;
+ TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal));
+
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_, shaped_buffer,
- [=, &shaped_buffer](
- xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
+ device_to_host_stream_, shaped_buffer, literal,
+ [=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
- TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
- VLOG(1) << "Transfer from device as literal: " << literal->ToString()
+ VLOG(1) << "Transfer from device as literal: " << literal.ToString()
<< " " << shaped_buffer.ToString();
- Tensor tensor;
- TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
- // Reshape the tensor back to its declared shape.
- Status status;
- if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
- status = errors::Internal(
- "Tensor::CopyFrom failed when copying from XLA device to CPU");
- }
return status;
}());
});
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 2fb66913ad..77da1bf29c 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
return Status::OK();
}
+Status HostTensorToMutableBorrowingLiteral(
+ Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(),
+ host_tensor->shape(), &xla_shape));
+ return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal);
+}
+
+Status HostTensorToMutableBorrowingLiteral(
+ const xla::Shape& xla_shape, Tensor* host_tensor,
+ xla::MutableBorrowingLiteral* literal) {
+ *literal = xla::MutableBorrowingLiteral(
+ static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape);
+
+ return Status::OK();
+}
+
Status HostTensorsToBorrowingLiteralTuple(
tensorflow::gtl::ArraySlice<Tensor> host_tensors,
xla::BorrowingLiteral* literal) {
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 0610a57029..09d6fa8116 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -30,6 +30,16 @@ namespace tensorflow {
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
+// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer
+// owned by 'host_tensor', but is mutable via the xla::Literal methods.
+Status HostTensorToMutableBorrowingLiteral(
+ Tensor* host_tensor, xla::MutableBorrowingLiteral* literal);
+// Similar as above, except the literal shape is explicitly provided and used
+// instead of obtaining it from the 'host_tensor'. The provided literal shape
+// 'xla_shape' must be compatible with the shape of 'host_tensor'.
+Status HostTensorToMutableBorrowingLiteral(
+ const xla::Shape& xla_shape, Tensor* host_tensor,
+ xla::MutableBorrowingLiteral* literal);
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 8a6c5fb9a7..4d96316d3b 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -303,7 +303,7 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
- auto literal = MakeUnique<Literal>();
+ auto literal = MakeUnique<Literal>(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, literal.get()));
return std::move(literal);
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 0545deb096..36e472568e 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -71,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
return out;
}
-Literal::StrideConfig::StrideConfig(
+MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions)
: dimensions(dimensions),
@@ -133,7 +133,8 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
}
Literal::Literal(const Shape& shape, bool allocate_arrays)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
@@ -159,7 +160,9 @@ void Literal::DeallocateBuffers() {
});
}
-Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+Literal::Literal(Literal&& other) : MutableLiteralBase() {
+ *this = std::move(other);
+}
Literal& Literal::operator=(Literal&& other) {
DCHECK(&other.root_piece_->subshape() == other.shape_.get());
@@ -187,12 +190,13 @@ const SparseIndexArray* LiteralBase::sparse_indices(
return piece(shape_index).sparse_indices();
}
-SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
+SparseIndexArray* MutableLiteralBase::sparse_indices(
+ const ShapeIndex& shape_index) {
return piece(shape_index).sparse_indices();
}
template <typename NativeT>
-Status Literal::CopySliceFromInternal(
+Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
@@ -225,8 +229,8 @@ Status Literal::CopySliceFromInternal(
// proper stride size at the matching dimension.
DimensionVector src_indexes(src_base.size(), 0);
DimensionVector dest_indexes(dest_base.size(), 0);
- Literal::StrideConfig stride_config(src_literal.shape(), shape(),
- copy_size);
+ MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
+ copy_size);
auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
// Map from multi-dimensional index, to source index.
@@ -253,9 +257,10 @@ Status Literal::CopySliceFromInternal(
return Status::OK();
}
-Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index) {
+Status MutableLiteralBase::CopyElementFrom(
+ const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
src_literal.shape(), src_index);
@@ -275,8 +280,8 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
- const LiteralProto& proto) {
+/* static */ StatusOr<std::unique_ptr<Literal>>
+MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -405,9 +410,9 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
return Status::OK();
}
-Status Literal::CopyFrom(const LiteralSlice& src_literal,
- const ShapeIndex& dest_shape_index,
- const ShapeIndex& src_shape_index) {
+Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index,
+ const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
ShapeUtil::GetSubshape(shape(), dest_shape_index);
const Shape& src_subshape =
@@ -482,10 +487,11 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK();
}
-Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
+Status MutableLiteralBase::CopySliceFrom(
+ const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
<< ShapeUtil::HumanString(src_literal.shape());
@@ -543,7 +549,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
shape().element_type());
}
-void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
+void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(element_count(), values.bits());
@@ -895,8 +901,8 @@ size_t LiteralBase::Hash() const {
return hash_value;
}
-Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value) {
+Status MutableLiteralBase::SetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@@ -933,7 +939,7 @@ tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
return p.sparse_indices()->At(sparse_element_number);
}
-void Literal::SortSparseElements(const ShapeIndex& shape_index) {
+void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements();
}
@@ -1391,11 +1397,11 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
elements.push_back(std::move(*new_element));
}
auto converted = MakeUnique<Literal>();
- *converted = Literal::MoveIntoTuple(&elements);
+ *converted = MutableLiteralBase::MoveIntoTuple(&elements);
return std::move(converted);
}
-/* static */ Literal Literal::MoveIntoTuple(
+/* static */ Literal MutableLiteralBase::MoveIntoTuple(
tensorflow::gtl::MutableArraySlice<Literal> elements) {
std::vector<Shape> element_shapes;
for (const Literal& element : elements) {
@@ -1808,7 +1814,8 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
- // These conditions should have been checked in Literal::CreateFromProto.
+ // These conditions should have been checked in
+ // MutableLiteralBase::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
@@ -1900,7 +1907,7 @@ const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
-void* Literal::untyped_data(const ShapeIndex& shape_index) {
+void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data();
}
@@ -1916,6 +1923,127 @@ string LiteralBase::GetR1U8AsString() const {
ShapeUtil::ElementsIn(shape()));
}
+void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape,
+ Piece* src_piece,
+ Piece* dest_piece) {
+ DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape()))
+ << "src_piece has shape: "
+ << ShapeUtil::HumanString(src_piece->subshape())
+ << "dest_piece has shape: "
+ << ShapeUtil::HumanString(dest_piece->subshape());
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece);
+
+ dest_piece->emplace_back(std::move(child_piece));
+ }
+ } else if (ShapeUtil::IsArray(shape)) {
+ dest_piece->set_buffer(src_piece->buffer());
+ } else {
+ // If the shape is neither an array nor tuple, then it must be
+ // zero-sized. Otherwise, some memory needs to be allocated for it.
+ CHECK_EQ(dest_piece->size_bytes(), 0);
+ }
+}
+
+MutableLiteralBase::~MutableLiteralBase() {}
+
+MutableBorrowingLiteral::MutableBorrowingLiteral(
+ const MutableBorrowingLiteral& literal)
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(literal.shape());
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+
+ CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
+}
+
+MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
+ const MutableBorrowingLiteral& literal) {
+ shape_ = MakeUnique<Shape>(literal.shape());
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+
+ CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
+
+ return *this;
+}
+
+MutableBorrowingLiteral::MutableBorrowingLiteral(
+ const MutableLiteralBase& literal)
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(literal.shape());
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+
+ CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_);
+}
+
+MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(literal->shape());
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+
+ CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_);
+}
+
+MutableBorrowingLiteral::MutableBorrowingLiteral(
+ MutableBorrowingLiteral literal, const ShapeIndex& view_root)
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+
+ CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_);
+}
+
+MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
+ const Shape& shape)
+ : MutableLiteralBase() {
+ shape_ = MakeUnique<Shape>(shape);
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ CHECK(!ShapeUtil::IsTuple(*shape_));
+
+ root_piece_ = new Piece();
+ root_piece_->set_buffer(const_cast<char*>(src_buf_ptr));
+ root_piece_->set_subshape(shape_.get());
+}
+
+MutableBorrowingLiteral::~MutableBorrowingLiteral() {
+ if (root_piece_ != nullptr) {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete piece->sparse_indices();
+ }
+ });
+ delete root_piece_;
+ }
+}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
CHECK(ShapeUtil::IsTuple(shape));
for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
@@ -1932,13 +2060,6 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
}
}
-LiteralSlice::LiteralSlice(const LiteralBase& literal)
- : LiteralBase(), root_piece_(&literal.root_piece()) {}
-
-LiteralSlice::LiteralSlice(const LiteralBase& literal,
- const ShapeIndex& view_root)
- : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
-
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_));
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index dd67dfa8d4..92c0f903cb 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -310,9 +310,10 @@ class LiteralBase {
// type of literal itself (0 for numeric types, and false for predicates).
//
// Note: It's an antipattern to use this method then immediately call
- // Literal::Populate on the result (since that results in zero initialization,
- // then reinitialization. Conside if a call to MakeUnique<Literal>(shape),
- // followed by the call to Literal::Populate can be used instead.
+ // MutableLiteralBase::Populate on the result (since that results in zero
+ // initialization, then reinitialization. Conside if a call to
+ // MakeUnique<Literal>(shape), followed by the call to
+ // MutableLiteralBase::Populate can be used instead.
static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
protected:
@@ -534,7 +535,7 @@ class LiteralBase {
virtual const Piece& root_piece() const = 0;
// LiteralSlice and Literal must access Pieces of other Literals.
- friend class Literal;
+ friend class MutableLiteralBase;
friend class LiteralSlice;
friend class BorrowingLiteral;
@@ -545,33 +546,10 @@ class LiteralBase {
tensorflow::gtl::ArraySlice<int64> start_indices) const;
};
-// Class representing literal values in XLA.
-//
-// The underlying buffer and shape is always owned by this class.
-class Literal : public LiteralBase {
+// Abstract base class representing a mutable literal in XLA.
+class MutableLiteralBase : public LiteralBase {
public:
- Literal() : Literal(ShapeUtil::MakeNil()) {}
-
- // Create a literal of the given shape. The literal is allocated sufficient
- // memory to hold the shape. Memory is uninitialized.
- explicit Literal(const Shape& shape);
- virtual ~Literal();
-
- // Literals are moveable, but not copyable. To copy a literal use
- // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
- // of literals which can be expensive.
- Literal(const Literal& other) = delete;
- Literal& operator=(const Literal& other) = delete;
- Literal(Literal&& other);
- // 'allocate_arrays' indicates whether to allocate memory for the arrays in
- // the shape. If false, buffer pointers inside of the Literal::Pieces are set
- // to nullptr.
- Literal(const Shape& shape, bool allocate_arrays);
- Literal& operator=(Literal&& other);
-
- // TODO(b/67651157): Remove this accessor. Literal users should not be able to
- // mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return shape_.get(); }
+ virtual ~MutableLiteralBase() = 0;
// Returns a MutableArraySlice view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
@@ -587,6 +565,10 @@ class Literal : public LiteralBase {
// is not a sparse array.
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
+
// Returns a pointer to the underlying buffer holding the array at the given
// shape index. CHECKs if the subshape of the literal at the given ShapeIndex
// is not array.
@@ -613,21 +595,6 @@ class Literal : public LiteralBase {
const ShapeIndex& dest_shape_index = {},
const ShapeIndex& src_shape_index = {});
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple)
- std::vector<Literal> DecomposeTuple();
-
- // Similar to CopyFrom, but with move semantincs. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
-
// Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size.
@@ -730,12 +697,7 @@ class Literal : public LiteralBase {
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
const LiteralProto& proto);
- private:
- // Recursively sets the subshapes and buffers of all subpieces rooted at
- // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
- // the shape.
- void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
-
+ protected:
// Returns the piece at the given ShapeIndex.
Piece& piece(const ShapeIndex& shape_index) {
return const_cast<Piece&>(LiteralBase::piece(shape_index));
@@ -783,12 +745,83 @@ class Literal : public LiteralBase {
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
+ friend class LiteralBase;
+ friend class MutableBorrowingLiteral;
+};
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
+
+// The underlying buffer and shape is always owned by this class.
+class Literal : public MutableLiteralBase {
+ public:
+ Literal() : Literal(ShapeUtil::MakeNil()) {}
+
+ // Create a literal of the given shape. The literal is allocated sufficient
+ // memory to hold the shape. Memory is uninitialized.
+ explicit Literal(const Shape& shape);
+ virtual ~Literal();
+
+ // Literals are moveable, but not copyable. To copy a literal use
+ // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+ // of literals which can be expensive.
+ Literal(const Literal& other) = delete;
+ Literal& operator=(const Literal& other) = delete;
+ Literal(Literal&& other);
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+ // the shape. If false, buffer pointers inside of the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
+ Literal& operator=(Literal&& other);
+
+ // Similar to CopyFrom, but with move semantincs. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ virtual Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple)
+ std::vector<Literal> DecomposeTuple();
+
+ private:
// Deallocate the buffers held by this literal.
void DeallocateBuffers();
- friend class LiteralBase;
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
+ // the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
+};
+
+// The underlying buffer is not owned by this class and is always owned by
+// others. The shape is not owned by this class and not mutable.
+class MutableBorrowingLiteral : public MutableLiteralBase {
+ public:
+ virtual ~MutableBorrowingLiteral();
+
+ MutableBorrowingLiteral() : MutableLiteralBase() {}
+
+ MutableBorrowingLiteral(const MutableBorrowingLiteral& literal);
+ MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal);
+
+ // Implicit conversion constructors.
+ MutableBorrowingLiteral(const MutableLiteralBase& literal);
+ MutableBorrowingLiteral(MutableLiteralBase* literal);
+ MutableBorrowingLiteral(MutableBorrowingLiteral literal,
+ const ShapeIndex& view_root);
+ MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
+
+ private:
+ // Recursively copies the subtree from the `src_piece` at the given child
+ // index to the `dest_piece`. For buffers only the pointers are copied, but
+ // not the content.
+ void CopyPieceSubtree(const Shape& shape, Piece* src_piece,
+ Piece* dest_piece);
};
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
// literal buffers always owned by others.
@@ -831,9 +864,9 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
- // Shape of this literal. Stored as unique_ptr so such that the (default)
- // move construction of this class would be trivially correct: the pointer to
- // Shape root_piece_ stores will still point to the correct address.
+ // Shape of this literal. Stored as unique_ptr such that the (default) move
+ // construction of this class would be trivially correct: the pointer to Shape
+ // root_piece_ stores will still point to the correct address.
std::unique_ptr<Shape> shape_;
};
@@ -886,7 +919,7 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}
@@ -904,14 +937,15 @@ inline NativeT LiteralBase::Get(
}
template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
+inline void MutableLiteralBase::Set(
+ tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}
template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
+inline void MutableLiteralBase::Set(
+ tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}
@@ -929,7 +963,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
}
template <typename NativeT>
-void Literal::AppendSparseElement(
+void MutableLiteralBase::AppendSparseElement(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
@@ -959,7 +993,8 @@ void LiteralBase::EachCell(
}
template <typename NativeT>
-inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+inline void MutableLiteralBase::PopulateR1(
+ tensorflow::gtl::ArraySlice<NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@@ -971,7 +1006,7 @@ inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
}
template <typename NativeT>
-void Literal::PopulateR2(
+void MutableLiteralBase::PopulateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 2);
@@ -996,7 +1031,7 @@ void Literal::PopulateR2(
}
template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
+void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
@@ -1009,24 +1044,24 @@ void Literal::PopulateFromArray(const Array<NativeT>& values) {
}
template <typename NativeT>
-void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
+void MutableLiteralBase::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
+void MutableLiteralBase::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
+void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
}
template <typename NativeT>
-void Literal::PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
+void MutableLiteralBase::PopulateSparse(
+ SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
CHECK_EQ(indices.rank(), rank);
@@ -1049,7 +1084,8 @@ void Literal::PopulateSparse(SparseIndexArray indices,
}
template <typename NativeT, typename FnType>
-Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
+Status MutableLiteralBase::PopulateInternal(const FnType& generator,
+ bool parallel) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@@ -1092,17 +1128,17 @@ Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
return Status::OK();
}
template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status MutableLiteralBase::Populate(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/false);
}
template <typename NativeT, typename FnType>
-Status Literal::PopulateParallel(const FnType& generator) {
+Status MutableLiteralBase::PopulateParallel(const FnType& generator) {
return PopulateInternal<NativeT>(generator, /*parallel=*/true);
}
template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value) {
+void MutableLiteralBase::PopulateWithValue(NativeT value) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(shape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>());
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 356f12ed78..5d33df7d40 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::strings::StrCat;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 156166bf2b..59bc7e0e16 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
Status CpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
- Literal* literal) {
+ MutableBorrowingLiteral literal) {
if (!ShapeUtil::IsTuple(literal_shape)) {
int64 size = GetByteSizeRequirement(literal_shape);
// Note: OSS build didn't like implicit conversion from
@@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::gtl::ArraySlice<int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
- *literal = std::move(*LiteralUtil::CreateFromDimensions(
- literal_shape.element_type(), dimensions));
- TF_ASSIGN_OR_RETURN(Shape received_shape,
- TransferArrayBufferFromOutfeed(
- executor, literal->untyped_data(), size));
- TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
+ TF_ASSIGN_OR_RETURN(
+ Shape received_shape,
+ TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
+ TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
<< "Shape received from outfeed "
<< ShapeUtil::HumanString(received_shape)
<< " did not match the shape that was requested for outfeed: "
<< ShapeUtil::HumanString(literal_shape);
TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
- *literal->mutable_shape_do_not_use() = received_shape;
+ *literal.mutable_shape_do_not_use() = received_shape;
return Status::OK();
}
@@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
"Nested tuple outfeeds are not yet implemented on CPU.");
}
- std::vector<std::unique_ptr<Literal>> elements;
std::vector<std::pair<void*, int64>> buffer_data;
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
const Shape& tuple_element_shape =
ShapeUtil::GetTupleElementShape(literal_shape, i);
- // Note: OSS build didn't like implicit conversion from
- // literal_shape.dimensions() to the array slice on 2017-07-10.
- tensorflow::gtl::ArraySlice<int64> dimensions(
- tensorflow::bit_cast<const int64*>(
- tuple_element_shape.dimensions().data()),
- tuple_element_shape.dimensions().size());
- auto empty = LiteralUtil::CreateFromDimensions(
- tuple_element_shape.element_type(), dimensions);
int64 size = GetByteSizeRequirement(tuple_element_shape);
- buffer_data.push_back({empty->untyped_data(), size});
- elements.push_back(std::move(empty));
+ buffer_data.push_back({literal.untyped_data({i}), size});
}
TF_ASSIGN_OR_RETURN(Shape received_shape,
@@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
GetByteSizeRequirement(received_shape));
- for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
- *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
- }
- *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
- TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
+ TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 593575c0fd..80ef953d53 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -41,7 +42,7 @@ class CpuTransferManager : public GenericTransferManager {
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
- Literal* literal) override;
+ MutableBorrowingLiteral literal) override;
private:
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index e314a469f0..0ce2db907b 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
void GenericTransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ MutableBorrowingLiteral literal, std::function<void(Status)> done) {
Status status = stream->BlockHostUntilDone();
if (!status.ok()) {
return done(status);
}
- done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
+
+ done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer,
+ literal));
}
-StatusOr<std::unique_ptr<Literal>>
-GenericTransferManager::TransferLiteralFromDeviceInternal(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
+Status GenericTransferManager::TransferLiteralFromDeviceInternal(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal) {
VLOG(2) << "transferring literal from device ordinal "
<< executor->device_ordinal() << "; device buffer: " << device_buffer;
TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
@@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
device_buffer.on_host_shape()));
- std::unique_ptr<Literal> literal =
- Literal::CreateFromShape(device_buffer.on_host_shape());
-
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
@@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
- literal->untyped_data(index)));
+ literal.untyped_data(index)));
}
return Status::OK();
}));
- return std::move(literal);
+ return Status::OK();
}
Status GenericTransferManager::TransferLiteralToDeviceAsync(
@@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
- Literal* literal) {
+ MutableBorrowingLiteral literal) {
return Unimplemented("Generic transfer from Outfeed");
}
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 3cd002c1bf..6c1a21587a 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -19,7 +19,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/transfer_manager.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager {
se::Platform::Id PlatformId() const override;
- void TransferLiteralFromDevice(
- se::Stream* stream, const ShapedBuffer& device_buffer,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
+ void TransferLiteralFromDevice(se::Stream* stream,
+ const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal,
+ std::function<void(Status)> done) override;
Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
@@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager {
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
- Literal* literal) override;
+ MutableBorrowingLiteral literal) override;
Status ResetDevices(
tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
@@ -67,8 +67,9 @@ class GenericTransferManager : public TransferManager {
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
+ Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor,
+ const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal);
// The platform this transfer manager targets.
const se::Platform::Id platform_id_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 79b3f1efec..a2f53f8446 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -117,38 +117,37 @@ StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
return std::move(buffer);
}
-static std::unique_ptr<Literal> ShapeTreeToLiteral(
+static void ShapeTreeToLiteral(
ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree) {
// This is a struct instead of a lambda for std::function-free recursion.
struct Helper {
- static std::unique_ptr<Literal> helper(
+ static void helper(
ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree,
ShapeIndex* index) {
const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index);
if (ShapeUtil::IsArray(shape)) {
- return (*shape_tree->mutable_element(*index))->WaitUntilAvailable();
+ (*shape_tree->mutable_element(*index))->WaitUntilAvailable();
+ return;
}
CHECK(ShapeUtil::IsTuple(shape))
<< ShapeUtil::HumanStringWithLayout(shape);
const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape);
index->push_back(0);
- std::vector<std::unique_ptr<Literal>> tuple_operands;
for (int64 i = 0; i < tuple_element_count; ++i) {
index->back() = i;
- tuple_operands.push_back(helper(shape_tree, index));
+ helper(shape_tree, index);
}
index->pop_back();
- return LiteralUtil::MakeTupleOwned(std::move(tuple_operands));
}
};
ShapeIndex index;
- return Helper::helper(shape_tree, &index);
+ Helper::helper(shape_tree, &index);
}
Status GpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* /*executor*/, const Shape& literal_shape,
- Literal* literal) {
+ MutableBorrowingLiteral literal) {
ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
&literal_shape);
@@ -162,6 +161,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
return;
}
*buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ (*buffer)->set_destination(
+ MakeUnique<MutableBorrowingLiteral>(literal, index));
});
// Give the tree of buffers to the outfeed mananger. The device will fill it
@@ -169,8 +170,8 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager();
outfeed_manager->EnqueueDestination(&outfeed_buffers);
- // Now turn the tree of buffers back into a literal.
- *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers));
+ // Now wait for the tree of buffers are written.
+ ShapeTreeToLiteral(&outfeed_buffers);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index dceeb9e2eb..7929042869 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -42,7 +42,7 @@ class GpuTransferManager : public GenericTransferManager {
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
- Literal* literal) override;
+ MutableBorrowingLiteral literal) override;
private:
// Initiates the infeed data transfers. InfeedBuffer->Done() must be
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
index a752eb7011..160ba4b691 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
@@ -36,22 +36,19 @@ class OutfeedBuffer {
OutfeedBuffer(int64 length) : length_(length) {}
// Waits for the device transfer to be finished.
- std::unique_ptr<Literal> WaitUntilAvailable() {
- done_.WaitForNotification();
- return std::move(destination_);
- }
+ void WaitUntilAvailable() { done_.WaitForNotification(); }
int64 length() const { return length_; }
- void set_destination(std::unique_ptr<Literal> destination) {
+ void set_destination(std::unique_ptr<MutableBorrowingLiteral> destination) {
destination_ = std::move(destination);
}
- Literal* destination() { return destination_.get(); }
+ MutableBorrowingLiteral* destination() { return destination_.get(); }
// Callback to signal that this buffer is consumed.
void Done() { done_.Notify(); }
private:
- std::unique_ptr<Literal> destination_;
+ std::unique_ptr<MutableBorrowingLiteral> destination_;
const int64 length_;
tensorflow::Notification done_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index 7986e63f43..b99d998c4d 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -50,10 +50,6 @@ Status OutfeedThunk::ExecuteOnStream(
if (!*buffer) { // Tuple pointers.
return Status::OK();
}
- // Allocate storage for the literal data.
- const Shape& shape =
- ShapeUtil::GetSubshape(outfeed_buffers->shape(), index);
- (*buffer)->set_destination(Literal::CreateFromShape(shape));
BufferAllocation::Slice slice = outfeed_slices_.element(index);
se::DeviceMemoryBase data_address;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 212db0643c..e970e885c5 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1052,10 +1052,10 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- Literal literal;
+ Literal literal(arg->shape_with_layout());
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), &literal));
+ executor, arg->shape_with_layout(), literal));
*result->mutable_literal() = literal.ToProto();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 7232c658b3..32d368a904 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -43,15 +43,39 @@ TransferManager::GetPlatformTransferManagers() {
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
StatusOr<std::unique_ptr<Literal>> ret;
+
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
tensorflow::Notification n;
- TransferLiteralFromDevice(substream, device_buffer,
- [&](StatusOr<std::unique_ptr<Literal>> arg) {
- ret = std::move(arg);
+ Status s;
+ Literal literal(device_buffer.on_host_shape());
+ TransferLiteralFromDevice(substream, device_buffer, literal,
+ [&](Status status) {
+ s = status;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ if (!s.ok()) {
+ return s;
+ }
+ return MakeUnique<Literal>(std::move(literal));
+}
+
+Status TransferManager::TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ const MutableBorrowingLiteral& literal) {
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+
+ Status ret;
+ tensorflow::Notification n;
+ TransferLiteralFromDevice(substream, device_buffer, literal,
+ [&](Status status) {
+ ret = status;
n.Notify();
});
n.WaitForNotification();
@@ -76,22 +100,27 @@ Status TransferManager::TransferLiteralToDevice(
StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
+ StatusOr<std::unique_ptr<Literal>> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
- StatusOr<std::unique_ptr<Literal>> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
tensorflow::Notification n;
- TransferArrayFromDevice(substream, shape, source,
- [&](StatusOr<std::unique_ptr<Literal>> arg) {
- ret = std::move(arg);
+ Literal literal(shape);
+ Status s;
+ TransferArrayFromDevice(substream, shape, source, literal,
+ [&](Status status) {
+ s = status;
n.Notify();
});
n.WaitForNotification();
- return ret;
+ if (!s.ok()) {
+ return s;
+ }
+ return MakeUnique<Literal>(std::move(literal));
}
Status TransferManager::TransferArrayToDevice(
@@ -130,7 +159,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
void TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ const MutableBorrowingLiteral& literal, std::function<void(Status)> done) {
if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) {
auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
" has a differently shaped representation on-device: ",
@@ -147,7 +176,8 @@ void TransferManager::TransferArrayFromDevice(
stream->parent()->platform(),
stream->parent()->device_ordinal());
shaped_buffer.set_buffer(source, /*index=*/{});
- return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done));
+ return TransferLiteralFromDevice(stream, shaped_buffer, literal,
+ std::move(done));
}
/* static */ void TransferManager::RegisterTransferManager(
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 82c599e482..475a2e5c14 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -59,6 +59,9 @@ class TransferManager {
// This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
+ virtual Status TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ const MutableBorrowingLiteral& literal);
// Begins transferring a literal containing the data held in the given
// ShapedBuffer using the provided executor.
@@ -69,9 +72,10 @@ class TransferManager {
//
// device_buffer is copied by reference and must live at least until done() is
// invoked.
- virtual void TransferLiteralFromDevice(
- se::Stream* stream, const ShapedBuffer& device_buffer,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0;
+ virtual void TransferLiteralFromDevice(se::Stream* stream,
+ const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal,
+ std::function<void(Status)> done) = 0;
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
@@ -101,10 +105,10 @@ class TransferManager {
// transfer an array at a known address.
Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- void TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done);
+ void TransferArrayFromDevice(se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source,
+ const MutableBorrowingLiteral& literal,
+ std::function<void(Status)> done);
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
@@ -120,9 +124,9 @@ class TransferManager {
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor.
- virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
- const Shape& literal_shape,
- Literal* literal) = 0;
+ virtual Status TransferLiteralFromOutfeed(
+ se::StreamExecutor* executor, const Shape& literal_shape,
+ MutableBorrowingLiteral literal) = 0;
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 2fd70b72b5..d9c1dfa3f7 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -586,9 +586,9 @@ XLA_TEST_F(TupleHloTest,
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = MakeUnique<Literal>();
+ auto literal = MakeUnique<Literal>(expected->shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), literal.get()));
+ backend().default_stream_executor(), expected->shape(), *literal));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
}