diff options
author | 2017-12-19 17:54:19 -0800 | |
---|---|---|
committer | 2017-12-20 10:59:57 -0800 | |
commit | 1988732f81bc5f61cd97c20952d5359fc0bf627f (patch) | |
tree | 159a855c36535cc72cd32992bec622f808b533e4 /tensorflow/compiler/xla/service/gpu/gpu_executable.cc | |
parent | d064a47543f51ff5a62927a76bb0fb0862d05558 (diff) |
[XLA:GPU] Make the use of scratch allocator in convolution_thunk safe.
Add member function Thunk::ShouldFutureScheduledThunksDependOn for
convolution_thunk to tell thunk executor that all future scheduled thunks
should wait for convolution_thunk. This can ensure that the use of scratch
allocator in convolution_thunk is safe.
PiperOrigin-RevId: 179628764
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/gpu_executable.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 366d87e9c3..df6c6668c3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + std::map<int32, const Thunk*> last_blocking_thunk_for_stream; std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { TF_RETURN_IF_ERROR(thunk->Initialize(*this)); @@ -167,10 +168,17 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + if (last_blocking_thunk_for_stream.count(stream_no)) { + stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, + last_blocking_thunk_for_stream[stream_no]) + .get()); + } + // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -178,11 +186,14 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk)) { + if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { auto finish_event = MakeUnique<se::Event>(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); + if (thunk->ShouldBlockFutureThunks()) { + last_blocking_thunk_for_stream[stream_no] = thunk; + } } profiler.FinishOperation(thunk->hlo_instruction()); } |