about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/jit/xla_device_context.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device_context.cc')
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.cc | 89
1 file changed, 55 insertions(+), 34 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0100bf51ed..0a0c089241 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include <memory>
+
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(compute_stream),
- host_to_device_stream_(host_to_device_stream),
- device_to_host_stream_(device_to_host_stream),
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : stream_(std::move(compute_stream)),
+ host_to_device_stream_(std::move(host_to_device_stream)),
+ device_to_host_stream_(std::move(device_to_host_stream)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- host_to_device_stream_->ThenWaitFor(stream_);
+ host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_, *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
- host_to_device_stream_->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(event.get());
+ xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_, shaped_buffer, literal,
+ device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback(
- [done]() { done(Status::OK()); });
+ host_to_device_stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback
+ // to avoid a deadlock. If done() is the callback that ends an
+ // Executor's run, the Executor may call XlaDevice::Sync() inside the
+ // callback. This deadlocks, because XlaDevice::Sync() waits for all
+ // stream activity to complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
return;
}
} else {
@@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
- host_to_device_stream_, block_status.error_message().c_str());
+ host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
- xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
device_to_host_stream_->ThenWaitFor(event);
- xla_tensor->SetDefinedOn(device_to_host_stream_);
+ xla_tensor->SetDefinedOn(device_to_host_stream_.get());
}
Status status;
@@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
+ "Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
@@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- device_to_device_stream->ThenWaitFor(stream_);
+ device_to_device_stream->ThenWaitFor(stream_.get());
}
}
if (se::Event* event =
- xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
device_to_device_stream->ThenWaitFor(event);
- xla_src->SetDefinedOn(device_to_device_stream);
+ xla_src->SetDefinedOn(device_to_device_stream.get());
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- CHECK(event.Init());
- device_to_device_stream->ThenRecordEvent(&event);
- xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize";
+ device_to_device_stream->ThenRecordEvent(event.get());
+ xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
- stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
+ stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback to avoid
+ // a deadlock. If done() is the callback that ends an Executor's run, the
+ // Executor may call XlaDevice::Sync() inside the callback. This
+ // deadlocks, because XlaDevice::Sync() waits for all stream activity to
+ // complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
}
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
- client, transfer_as_literal,
- std::move(shape_representation_fn)) {}
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : manager_(std::move(compute_stream), std::move(host_to_device_stream),
+ std::move(device_to_host_stream), client, transfer_as_literal,
+ std::move(shape_representation_fn), thread_pool) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,