From 335336aa2cdf853d380c3e22ab6694ff78cb487a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Aug 2018 07:50:21 -0700 Subject: Implement DoHostCallbackWithStatus to allow callbacks to return a status PiperOrigin-RevId: 207714420 --- tensorflow/stream_executor/stream.cc | 13 +++++++++++++ tensorflow/stream_executor/stream.h | 5 +++++ tensorflow/stream_executor/stream_executor_internal.cc | 12 ++++++++++++ tensorflow/stream_executor/stream_executor_internal.h | 2 ++ tensorflow/stream_executor/stream_executor_pimpl.cc | 5 +++++ tensorflow/stream_executor/stream_executor_pimpl.h | 5 +++++ 6 files changed, 42 insertions(+) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index a42a469df5..9efd34de24 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -5294,6 +5294,19 @@ Stream &Stream::ThenDoHostCallback(std::function callback) { return *this; } +Stream &Stream::ThenDoHostCallbackWithStatus( + std::function callback) { + VLOG_CALL(PARAM(callback)); + + if (ok()) { + CheckError(parent_->HostCallback(this, std::move(callback))); + } else { + LOG(WARNING) << "stream " << DebugStreamPointers() + << " was in error state before adding host callback"; + } + return *this; +} + Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory> &input, DeviceMemory> *output) { diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 4d41409fef..e1629b5b30 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -2045,6 +2045,11 @@ class Stream { // negative effects on performance. Stream &ThenDoHostCallback(std::function callback); + // Entrains onto the stream a callback to the host (from the device). + // Behaves as ThenDoHostCallback above, but returns a Status instead of void. + // This overload should be preferred if the callback could fail. + Stream &ThenDoHostCallbackWithStatus(std::function callback); + // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index 8297228e6f..7df6a361c6 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { StreamExecutorFactory MakeHostExecutorImplementation; +// TODO(b/112125301): Consolodate this down to one implementation of +// HostCallback, taking a callback that returns a Status. +bool StreamExecutorInterface::HostCallback( + Stream* stream, std::function callback) { + return HostCallback(stream, [callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "HostCallback failed: " << s; + } + }); +} + } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index f34b1fc083..92e5376835 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -239,6 +239,8 @@ class StreamExecutorInterface { const DeviceMemoryBase &host_src, uint64 size) = 0; virtual bool HostCallback(Stream *stream, std::function callback) = 0; + virtual bool HostCallback(Stream *stream, + std::function callback); virtual port::Status AllocateEvent(Event *event) = 0; virtual port::Status DeallocateEvent(Event *event) = 0; virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 2e0137a485..9515d8e62a 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -699,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream, return implementation_->HostCallback(stream, std::move(callback)); } +bool StreamExecutor::HostCallback(Stream *stream, + std::function callback) { + return implementation_->HostCallback(stream, std::move(callback)); +} + port::Status StreamExecutor::AllocateEvent(Event *event) { return implementation_->AllocateEvent(event); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 47b3a2b030..437f298616 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -549,6 +549,11 @@ class StreamExecutor { // See Stream::ThenDoHostCallback for full details. bool HostCallback(Stream *stream, std::function callback); + // Entrains on a stream a user-specified function to be run on the host. + // See Stream::ThenDoHostCallback for full details. + // This is the preferred form for a callback that may return an error. + bool HostCallback(Stream *stream, std::function callback); + // Performs platform-specific allocation and initialization of an event. port::Status AllocateEvent(Event *event); -- cgit v1.2.3