aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device_context.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device_context.cc')
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc26
1 files changed, 8 insertions, 18 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0a0c089241..ee07c5c964 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -91,7 +91,8 @@ Status XlaTransferManager::TransferLiteralToDevice(
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
- if (UseMultipleStreams()) {
+ if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow(
+ stream_->parent(), shaped_buffer)) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
@@ -123,11 +124,11 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_.get(), shaped_buffer, literal,
- [=, &shaped_buffer, &literal](xla::Status status) {
+ [=, &shaped_buffer](xla::Status status) {
ref.Unref();
done([&]() -> Status {
- VLOG(1) << "Transfer from device as literal: " << literal.ToString()
- << " " << shaped_buffer.ToString();
+ VLOG(1) << "Transfer from device as literal: "
+ << shaped_buffer.ToString();
return status;
}());
});
@@ -183,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- if (status.ok()) {
- xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback
- // to avoid a deadlock. If done() is the callback that ends an
- // Executor's run, the Executor may call XlaDevice::Sync() inside the
- // callback. This deadlocks, because XlaDevice::Sync() waits for all
- // stream activity to complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
- return;
- }
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
@@ -207,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
- xla_tensor->set_host_tensor(*cpu_tensor);
-
+ if (status.ok()) {
+ xla_tensor->set_host_tensor(*cpu_tensor);
+ }
done(status);
}