aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/host/host_stream.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/host/host_stream.cc')
-rw-r--r--tensorflow/stream_executor/host/host_stream.cc26
1 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc
index 5a7d3b3dd4..bfbfb56cd7 100644
--- a/tensorflow/stream_executor/host/host_stream.cc
+++ b/tensorflow/stream_executor/host/host_stream.cc
@@ -28,18 +28,28 @@ HostStream::HostStream()
HostStream::~HostStream() {}
bool HostStream::EnqueueTask(std::function<void()> task) {
+ struct NotifiedTask {
+ HostStream* stream;
+ std::function<void()> task;
+
+ void operator()() {
+ task();
+ // Destroy the task before unblocking its waiters, as BlockHostUntilDone()
+ // should guarantee that all tasks are destroyed.
+ task = std::function<void()>();
+ {
+ mutex_lock lock(stream->mu_);
+ --stream->pending_tasks_;
+ }
+ stream->completion_condition_.notify_all();
+ }
+ };
+
{
mutex_lock lock(mu_);
++pending_tasks_;
}
- host_executor_->Schedule([this, task]() {
- task();
- {
- mutex_lock lock(mu_);
- --pending_tasks_;
- }
- completion_condition_.notify_all();
- });
+ host_executor_->Schedule(NotifiedTask{this, std::move(task)});
return true;
}