aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-19 17:54:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-20 10:59:57 -0800
commit1988732f81bc5f61cd97c20952d5359fc0bf627f (patch)
tree159a855c36535cc72cd32992bec622f808b533e4 /tensorflow/compiler/xla/service/gpu/gpu_executable.cc
parentd064a47543f51ff5a62927a76bb0fb0862d05558 (diff)
[XLA:GPU] Make the use of scratch allocator in convolution_thunk safe.
Add member function Thunk::ShouldFutureScheduledThunksDependOn for convolution_thunk to tell thunk executor that all future scheduled thunks should wait for convolution_thunk. This can ensure that the use of scratch allocator in convolution_thunk is safe. PiperOrigin-RevId: 179628764
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/gpu_executable.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc13
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 366d87e9c3..df6c6668c3 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks(
run_options->BorrowStream(main_stream->parent()->device_ordinal()));
}
+ std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this));
@@ -167,10 +168,17 @@ Status GpuExecutable::ExecuteThunks(
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
+ if (last_blocking_thunk_for_stream.count(stream_no)) {
+ stream->ThenWaitFor(FindOrDie(thunk_to_finish_event,
+ last_blocking_thunk_for_stream[stream_no])
+ .get());
+ }
+
// If this thunk requests it, wait for all currently-executing thunks to
// finish. This is useful e.g. if the thunk is about to perform autotuning.
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
+ last_blocking_thunk_for_stream.clear();
}
profiler.StartOperation();
@@ -178,11 +186,14 @@ Status GpuExecutable::ExecuteThunks(
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
- if (thunk_schedule_->Depended(thunk)) {
+ if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
+ if (thunk->ShouldBlockFutureThunks()) {
+ last_blocking_thunk_for_stream[stream_no] = thunk;
+ }
}
profiler.FinishOperation(thunk->hlo_instruction());
}