diff options
author | 2018-06-27 05:19:14 -0700 | |
---|---|---|
committer | 2018-06-28 21:37:43 -0700 | |
commit | c2d369373d7e0cbdb01be9f556a5a36ff3ce6cf6 (patch) | |
tree | cbf006b41429f239d72c42e1638cdc728a86426a | |
parent | 8f1061f846c76c882b75de42d9bda395822cf666 (diff) |
[XLA] Support asynchronous execution through XLA
PiperOrigin-RevId: 202292422
-rw-r--r-- | tensorflow/compiler/jit/xla_compilation_cache.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.cc | 103 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/executable.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_runner.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/local_client_execute_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/local_client_test_base.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/stream_executor/host/host_gpu_executor.cc | 2 |
9 files changed, 117 insertions, 52 deletions
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7ed609c437..54a41a4daa 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,7 +40,23 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} -XlaCompilationCache::~XlaCompilationCache() = default; +XlaCompilationCache::~XlaCompilationCache() { + // Ensure any use of our programs have completed by waiting for all stream + // executors to complete. + for (auto* executor : client_->backend().stream_executors()) { + bool ok = executor->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "Error synchronizing activity while waiting for all " + "programs to complete"; + } + } + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} string XlaCompilationCache::DebugString() { return "XLA JIT compilation cache"; diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 37005479dc..e20f5aa837 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -67,36 +67,53 @@ Status XlaTransferManager::TransferLiteralToDevice( xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); - xla::BorrowingLiteral literal( + // Create a reference to hold onto host_tensor until after the literal has + // been transferred. Also make sure the literal exists until the function + // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. + TensorReference ref(host_tensor); + auto literal = std::make_shared<xla::BorrowingLiteral>( static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - return transfer_manager_->TransferLiteralToDevice(stream_, literal, - shaped_buffer); + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + stream_, *literal, shaped_buffer)); + // Unref the host tensor, and capture the literal shared_ptr too so it goes + // out of scope when the lambda completes. + stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } -Status XlaTransferManager::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor) const { +void XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor, + const StatusCallback& done) const { const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr<xla::Literal> literal, - transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " - << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - return errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } - return Status::OK(); + TensorReference ref(device_tensor); + transfer_manager_->TransferLiteralFromDevice( + stream_, shaped_buffer, + [=, &shaped_buffer]( + xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) { + ref.Unref(); + done([&]() -> Status { + TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() + << " " << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + Status status; + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + status = errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return status; + }()); + }); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -121,17 +138,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, TensorShape shape = shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); + Status status; if (!xla_tensor->has_shaped_buffer()) { - Status s = xla_tensor->AllocateShapedBuffer( + status = xla_tensor->AllocateShapedBuffer( device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; + if (!status.ok()) { + return done(status); } } - Status status; if (transfer_as_literal_) { Tensor reshaped_cpu_tensor; if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { @@ -184,7 +200,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status status; if (transfer_as_literal_) { - status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); + TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); + return; } else { stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. @@ -194,9 +211,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, "Failed to complete data transfer on stream %p: %s", stream_, block_status.error_message().c_str()); } + done(status); } - - done(status); return; } @@ -207,8 +223,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done) { - // TODO(phawkins): replace this code with an asynchronous implementation. - auto body = [&]() { + // Perform memory allocation now, and enqueue the device-to-device transfer. + Status status = [&]() -> Status { if (src_tensor.NumElements() == 0) { return Status::OK(); } @@ -223,21 +239,20 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, stream_->parent()->device_ordinal())); } - TF_RETURN_IF_ERROR( - xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( - [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { - const se::DeviceMemoryBase& from_buffer = - xla_src->shaped_buffer().buffers().element(index); - CHECK_EQ(buffer->size(), from_buffer.size()); - if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, - buffer->size())) { - return errors::Internal("Device to device memcpy failed"); - } - return Status::OK(); - })); + auto from_iter = xla_src->shaped_buffer().buffers().begin(); + auto to_iter = xla_dst->shaped_buffer().buffers().begin(); + for (auto end_iter = xla_src->shaped_buffer().buffers().end(); + from_iter != end_iter; ++from_iter, ++to_iter) { + stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second, + to_iter->second.size()); + } return Status::OK(); - }; - done(body()); + }(); + if (!status.ok()) { + return done(status); + } else { + stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + } } XlaDeviceContext::XlaDeviceContext( diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index ee346e5653..c5c81d65fe 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -64,8 +64,9 @@ class XlaTransferManager { private: Status TransferLiteralToDevice(const Tensor& host_tensor, Tensor* device_tensor) const; - Status TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor) const; + void TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor, + const StatusCallback& done) const; // Stream obtained from a Device, used to transfer tensors between // CPU and device. diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 7cf2746947..fd75847d0c 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -82,7 +82,18 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper( StatusOr<ScopedShapedBuffer> return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); - TF_RETURN_IF_ERROR(return_value.status()); + if (!return_value.status().ok()) { + if (profile != nullptr) { + // Ensure the ThenStartTimer call has completed before we destroy timer. + // We already have a failure status to return, so just log this if it + // fails. + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + LOG(ERROR) << "Failed to BlockHostUntilDone: " << status; + } + } + return return_value.status(); + } if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 4f0569f405..b2725e2918 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -180,8 +180,12 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, CreateExecutable(std::move(module), run_hlo_passes)); - return executable->ExecuteOnStreamWrapper(&service_run_options, - /*profile=*/profile, arguments); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); } StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( @@ -309,6 +313,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( std::vector<std::unique_ptr<Literal>> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { + TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 77f9c33ee1..8a903f1e6d 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -772,6 +772,10 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); + ASSERT_IS_OK(local_client_->mutable_backend() + ->BorrowStream(0) + .ValueOrDie() + ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near<float>( {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 88797a7d0a..c31ba0e713 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -189,7 +189,19 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally( TF_ASSIGN_OR_RETURN( std::unique_ptr<LocalExecutable> executable, local_client_->Compile(computation, argument_layouts, build_options)); - return executable->Run(arguments, run_options); + TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options)); + + auto device_ordinal = + build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal(); + auto* stream = run_options.stream(); + if (!stream) { + stream = local_client_->mutable_backend() + ->BorrowStream(device_ordinal) + .ValueOrDie() + .get(); + } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return std::move(ret); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index b081850eb5..e7074915ee 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -168,6 +168,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto execution_result, executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); + TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; *profile_output = diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 2c4819651a..c8a6297330 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -95,7 +95,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, // the nature of the HostExecutor) memcpy on the stream (HostStream) // associated with the HostExecutor. AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); return true; } |