aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/transfer_manager.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/transfer_manager.h')
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h24
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 82c599e482..475a2e5c14 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -59,6 +59,9 @@ class TransferManager {
// This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
+ virtual Status TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ const MutableBorrowingLiteral& literal);
// Begins transferring a literal containing the data held in the given
// ShapedBuffer using the provided executor.
@@ -69,9 +72,10 @@ class TransferManager {
//
// device_buffer is copied by reference and must live at least until done() is
// invoked.
- virtual void TransferLiteralFromDevice(
- se::Stream* stream, const ShapedBuffer& device_buffer,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0;
+ virtual void TransferLiteralFromDevice(se::Stream* stream,
+ const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal,
+ std::function<void(Status)> done) = 0;
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
@@ -101,10 +105,10 @@ class TransferManager {
// transfer an array at a known address.
Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- void TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done);
+ void TransferArrayFromDevice(se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source,
+ const MutableBorrowingLiteral& literal,
+ std::function<void(Status)> done);
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
@@ -120,9 +124,9 @@ class TransferManager {
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor.
- virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
- const Shape& literal_shape,
- Literal* literal) = 0;
+ virtual Status TransferLiteralFromOutfeed(
+ se::StreamExecutor* executor, const Shape& literal_shape,
+ MutableBorrowingLiteral literal) = 0;
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(