diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/transfer_manager.h')
-rw-r--r-- | tensorflow/compiler/xla/service/transfer_manager.h | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 82c599e482..475a2e5c14 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -59,6 +59,9 @@ class TransferManager { // This function should be avoided in favor of the asynchronous version below. virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function<void(Status)> done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,10 +105,10 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function<void(StatusOr<std::unique_ptr<Literal>>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function<void(Status)> done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, @@ -120,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( |