author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-27 05:19:14 -0700
committer Gunhan Gulsoy <gunan@google.com>                  2018-06-28 21:37:43 -0700
commit    c2d369373d7e0cbdb01be9f556a5a36ff3ce6cf6 (patch)
tree      cbf006b41429f239d72c42e1638cdc728a86426a
parent    8f1061f846c76c882b75de42d9bda395822cf666 (diff)
[XLA] Support asynchronous execution through XLA
PiperOrigin-RevId: 202292422
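The commit makes XLA device transfers and execution asynchronous: work is enqueued on a StreamExecutor stream, and host callbacks release host-side buffers only after the enqueued work has drained. Below is a minimal standalone sketch of that lifetime-extension pattern, using std::async and a shared_ptr where the real code uses a stream, a TensorReference, and ThenDoHostCallback; the names here are illustrative, not the TensorFlow API.

#include <future>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for an enqueued device transfer: it runs asynchronously and reads
// the host buffer it was handed.
std::future<void> EnqueueTransfer(std::shared_ptr<std::vector<float>> literal) {
  return std::async(std::launch::async, [literal] {
    // The shared_ptr captured here keeps the host data alive until the
    // "transfer" has actually run, mirroring the TensorReference and literal
    // shared_ptr captured by ThenDoHostCallback in the diff below.
    std::cout << "transferred " << literal->size() << " floats\n";
  });
}

int main() {
  std::future<void> pending;
  {
    auto literal = std::make_shared<std::vector<float>>(1024, 0.5f);
    pending = EnqueueTransfer(literal);
    // The caller's reference goes out of scope immediately; the data survives
    // because the pending task still holds one.
  }
  pending.wait();
}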
 tensorflow/compiler/jit/xla_compilation_cache.cc           |  18
 tensorflow/compiler/jit/xla_device_context.cc              | 103
 tensorflow/compiler/jit/xla_device_context.h               |   5
 tensorflow/compiler/xla/service/executable.cc              |  13
 tensorflow/compiler/xla/service/hlo_runner.cc              |   9
 tensorflow/compiler/xla/tests/local_client_execute_test.cc |   4
 tensorflow/compiler/xla/tests/local_client_test_base.cc    |  14
 tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc      |   1
 tensorflow/stream_executor/host/host_gpu_executor.cc       |   2
 9 files changed, 117 insertions(+), 52 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 7ed609c437..54a41a4daa 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -40,7 +40,23 @@ namespace tensorflow {
XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
DeviceType device_type)
: client_(client), device_type_(std::move(device_type)) {}
-XlaCompilationCache::~XlaCompilationCache() = default;
+XlaCompilationCache::~XlaCompilationCache() {
+ // Ensure all uses of our programs have completed by waiting for all stream
+ // executors to complete.
+ for (auto* executor : client_->backend().stream_executors()) {
+ bool ok = executor->SynchronizeAllActivity();
+ if (!ok) {
+ LOG(ERROR) << "Error synchronizing activity while waiting for all "
+ "programs to complete";
+ }
+ }
+ // TODO(b/110813685): Think about the program ownership model. Programs are
+ // currently owned by the compilation cache which means we must wait for
+ // program completion in the destructor. There are multiple compilation caches
+ // around, which complicates things a little. Perhaps having programs be
+ // shared_ptrs (an invasive change) would make the model easier to reason
+ // about?
+}
string XlaCompilationCache::DebugString() {
return "XLA JIT compilation cache";
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 37005479dc..e20f5aa837 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -67,36 +67,53 @@ Status XlaTransferManager::TransferLiteralToDevice(
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
- xla::BorrowingLiteral literal(
+ // Create a reference to hold onto host_tensor until after the literal has
+ // been transferred. Also make sure the literal exists until the function
+ // asynchronously completes, as it will be wrapped in an xla::LiteralSlice.
+ TensorReference ref(host_tensor);
+ auto literal = std::make_shared<xla::BorrowingLiteral>(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
- VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
+ VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
- return transfer_manager_->TransferLiteralToDevice(stream_, literal,
- shaped_buffer);
+ TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
+ stream_, *literal, shaped_buffer));
+ // Unref the host tensor, and capture the literal shared_ptr too so it goes
+ // out of scope when the lambda completes.
+ stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
+ return Status::OK();
}
-Status XlaTransferManager::TransferLiteralFromDevice(
- Tensor* host_tensor, const Tensor& device_tensor) const {
+void XlaTransferManager::TransferLiteralFromDevice(
+ Tensor* host_tensor, const Tensor& device_tensor,
+ const StatusCallback& done) const {
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<xla::Literal> literal,
- transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
- VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
- << shaped_buffer.ToString();
- Tensor tensor;
- TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
- // Reshape the tensor back to its declared shape.
- if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
- return errors::Internal(
- "Tensor::CopyFrom failed when copying from XLA device to CPU");
- }
- return Status::OK();
+ TensorReference ref(device_tensor);
+ transfer_manager_->TransferLiteralFromDevice(
+ stream_, shaped_buffer,
+ [=, &shaped_buffer](
+ xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
+ ref.Unref();
+ done([&]() -> Status {
+ TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
+ VLOG(1) << "Transfer from device as literal: " << literal->ToString()
+ << " " << shaped_buffer.ToString();
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(
+ LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
+ // Reshape the tensor back to its declared shape.
+ Status status;
+ if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
+ status = errors::Internal(
+ "Tensor::CopyFrom failed when copying from XLA device to CPU");
+ }
+ return status;
+ }());
+ });
}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -121,17 +138,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
TensorShape shape = shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype());
+ Status status;
if (!xla_tensor->has_shaped_buffer()) {
- Status s = xla_tensor->AllocateShapedBuffer(
+ status = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
- if (!s.ok()) {
- done(s);
- return;
+ if (!status.ok()) {
+ return done(status);
}
}
- Status status;
if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
@@ -184,7 +200,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status status;
if (transfer_as_literal_) {
- status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
+ TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
+ return;
} else {
stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
@@ -194,9 +211,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
+ done(status);
}
-
- done(status);
return;
}
@@ -207,8 +223,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
Tensor* dst_tensor,
const StatusCallback& done) {
- // TODO(phawkins): replace this code with an asynchronous implementation.
- auto body = [&]() {
+ // Perform memory allocation now, and enqueue the device-to-device transfer.
+ Status status = [&]() -> Status {
if (src_tensor.NumElements() == 0) {
return Status::OK();
}
@@ -223,21 +239,20 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
}
- TF_RETURN_IF_ERROR(
- xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
- [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
- const se::DeviceMemoryBase& from_buffer =
- xla_src->shaped_buffer().buffers().element(index);
- CHECK_EQ(buffer->size(), from_buffer.size());
- if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
- buffer->size())) {
- return errors::Internal("Device to device memcpy failed");
- }
- return Status::OK();
- }));
+ auto from_iter = xla_src->shaped_buffer().buffers().begin();
+ auto to_iter = xla_dst->shaped_buffer().buffers().begin();
+ for (auto end_iter = xla_src->shaped_buffer().buffers().end();
+ from_iter != end_iter; ++from_iter, ++to_iter) {
+ stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second,
+ to_iter->second.size());
+ }
return Status::OK();
- };
- done(body());
+ }();
+ if (!status.ok()) {
+ return done(status);
+ } else {
+ stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
+ }
}
XlaDeviceContext::XlaDeviceContext(
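The device-to-host path above switches from a Status-returning synchronous call to a void function that takes a StatusCallback, and the callback body wraps its work in an immediately-invoked lambda so every early return funnels into a single done(status) call. A sketch of that shape with toy Status and callback types (not the TensorFlow definitions):

#include <functional>
#include <iostream>
#include <string>

// Toy stand-ins for tensorflow::Status and StatusCallback; illustrative only.
struct Status {
  bool ok;
  std::string message;
  static Status OK() { return {true, ""}; }
  static Status Internal(std::string m) { return {false, std::move(m)}; }
};
using StatusCallback = std::function<void(const Status&)>;

// The result is reported through done rather than returned. The
// immediately-invoked lambda lets each early return end up in exactly one
// done() call, mirroring TransferLiteralFromDevice above.
void TransferFromDevice(bool simulate_copy_failure, const StatusCallback& done) {
  done([&]() -> Status {
    if (simulate_copy_failure) {
      return Status::Internal("Tensor::CopyFrom failed");
    }
    // ... convert the device literal to a host tensor here ...
    return Status::OK();
  }());
}

int main() {
  auto report = [](const Status& s) {
    std::cout << (s.ok ? "ok" : "error: " + s.message) << "\n";
  };
  TransferFromDevice(false, report);  // prints "ok"
  TransferFromDevice(true, report);   // prints "error: Tensor::CopyFrom failed"
}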
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index ee346e5653..c5c81d65fe 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -64,8 +64,9 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
Tensor* device_tensor) const;
- Status TransferLiteralFromDevice(Tensor* host_tensor,
- const Tensor& device_tensor) const;
+ void TransferLiteralFromDevice(Tensor* host_tensor,
+ const Tensor& device_tensor,
+ const StatusCallback& done) const;
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 7cf2746947..fd75847d0c 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -82,7 +82,18 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
StatusOr<ScopedShapedBuffer> return_value =
ExecuteOnStream(run_options, arguments, profile_ptr.get());
- TF_RETURN_IF_ERROR(return_value.status());
+ if (!return_value.status().ok()) {
+ if (profile != nullptr) {
+ // Ensure the ThenStartTimer call has completed before we destroy timer.
+ // We already have a failure status to return, so just log this if it
+ // fails.
+ Status status = stream->BlockHostUntilDone();
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to BlockHostUntilDone: " << status;
+ }
+ }
+ return return_value.status();
+ }
if (profile != nullptr) {
VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 4f0569f405..b2725e2918 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -180,8 +180,12 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
- return executable->ExecuteOnStreamWrapper(&service_run_options,
- /*profile=*/profile, arguments);
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer retval,
+ executable->ExecuteOnStreamWrapper(&service_run_options,
+ /*profile=*/profile, arguments));
+ TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
+ return std::move(retval);
}
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
@@ -309,6 +313,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
std::vector<std::unique_ptr<Literal>> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
+ TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
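HloRunner now returns before the computation has necessarily finished, so it blocks on the stream before handing buffers back, and ExecuteReplicated blocks per replica before transferring results. The caller-side rule, sketched with std::async in place of a stream, is "synchronize before you read":

#include <future>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> result(3);
  // The "execution" is enqueued asynchronously and writes into the result buffer.
  std::future<void> pending = std::async(std::launch::async, [&result] {
    std::iota(result.begin(), result.end(), 1.0f);  // fills 1, 2, 3
  });
  // With asynchronous execution, callers must synchronize (the analogue of
  // stream.BlockHostUntilDone()) before reading the output buffer.
  pending.wait();
  for (float v : result) std::cout << v << " ";
  std::cout << "\n";  // prints "1 2 3"
}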
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 77f9c33ee1..8a903f1e6d 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -772,6 +772,10 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
+ ASSERT_IS_OK(local_client_->mutable_backend()
+ ->BorrowStream(0)
+ .ValueOrDie()
+ ->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
{2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 88797a7d0a..c31ba0e713 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -189,7 +189,19 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<LocalExecutable> executable,
local_client_->Compile(computation, argument_layouts, build_options));
- return executable->Run(arguments, run_options);
+ TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
+
+ auto device_ordinal =
+ build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
+ auto* stream = run_options.stream();
+ if (!stream) {
+ stream = local_client_->mutable_backend()
+ ->BorrowStream(device_ordinal)
+ .ValueOrDie()
+ .get();
+ }
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ return std::move(ret);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index b081850eb5..e7074915ee 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -168,6 +168,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
auto execution_result,
executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
&hlo_execution_profile));
+ TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
(void)execution_result;
*profile_output =
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 2c4819651a..c8a6297330 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -95,7 +95,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
// the nature of the HostExecutor) memcpy on the stream (HostStream)
// associated with the HostExecutor.
AsHostStream(stream)->EnqueueTask(
- [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); });
+ [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); });
return true;
}
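The final hunk fixes a genuine argument-order bug: memcpy takes the destination pointer first, so the old call copied in the wrong direction. For reference:

#include <cstring>
#include <iostream>

int main() {
  char src[4] = {'x', 'l', 'a', '\0'};
  char dst[4] = {};
  // void* memcpy(void* dest, const void* src, size_t n): destination first.
  std::memcpy(dst, src, sizeof(src));
  std::cout << dst << "\n";  // prints "xla"; with the arguments swapped, src
                             // would be overwritten and dst left empty.
}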