diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_device_context.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.cc | 26 |
1 file changed, 8 insertions, 18 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 0a0c089241..ee07c5c964 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -91,7 +91,8 @@ Status XlaTransferManager::TransferLiteralToDevice( const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - if (UseMultipleStreams()) { + if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( + stream_->parent(), shaped_buffer)) { // Initially wait for the compute stream so that memory allocations are // synchronized. host_to_device_stream_->ThenWaitFor(stream_.get()); @@ -123,11 +124,11 @@ void XlaTransferManager::TransferLiteralFromDevice( TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( device_to_host_stream_.get(), shaped_buffer, literal, - [=, &shaped_buffer, &literal](xla::Status status) { + [=, &shaped_buffer](xla::Status status) { ref.Unref(); done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " << literal.ToString() - << " " << shaped_buffer.ToString(); + VLOG(1) << "Transfer from device as literal: " + << shaped_buffer.ToString(); return status; }()); }); @@ -183,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback([this, done]() { - // We must not call the done closure directly from DoHostCallback - // to avoid a deadlock. If done() is the callback that ends an - // Executor's run, the Executor may call XlaDevice::Sync() inside the - // callback. This deadlocks, because XlaDevice::Sync() waits for all - // stream activity to complete. 
- thread_pool_->Schedule([done]() { done(Status::OK()); }); - }); - return; - } } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); @@ -207,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get(), block_status.error_message().c_str()); } } - xla_tensor->set_host_tensor(*cpu_tensor); - + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + } done(status); } |