aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/generic_transfer_manager.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/generic_transfer_manager.cc')
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc22
1 files changed, 10 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index e314a469f0..0ce2db907b 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
void GenericTransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
- std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ MutableBorrowingLiteral literal, std::function<void(Status)> done) {
Status status = stream->BlockHostUntilDone();
if (!status.ok()) {
return done(status);
}
- done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
+
+ done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer,
+ literal));
}
-StatusOr<std::unique_ptr<Literal>>
-GenericTransferManager::TransferLiteralFromDeviceInternal(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
+Status GenericTransferManager::TransferLiteralFromDeviceInternal(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer,
+ MutableBorrowingLiteral literal) {
VLOG(2) << "transferring literal from device ordinal "
<< executor->device_ordinal() << "; device buffer: " << device_buffer;
TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
@@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
device_buffer.on_host_shape()));
- std::unique_ptr<Literal> literal =
- Literal::CreateFromShape(device_buffer.on_host_shape());
-
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
@@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal(
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
- literal->untyped_data(index)));
+ literal.untyped_data(index)));
}
return Status::OK();
}));
- return std::move(literal);
+ return Status::OK();
}
Status GenericTransferManager::TransferLiteralToDeviceAsync(
@@ -160,7 +158,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
- Literal* literal) {
+ MutableBorrowingLiteral literal) {
return Unimplemented("Generic transfer from Outfeed");
}