aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-07 07:50:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 07:54:17 -0700
commit335336aa2cdf853d380c3e22ab6694ff78cb487a (patch)
tree17a9686337bd3cd7444f66fb8e796cb01e04bfc6 /tensorflow/stream_executor
parenta1a370cb8b8ef43996a275b64ada81f9cb32e743 (diff)
Implement DoHostCallbackWithStatus to allow callbacks to return a status
PiperOrigin-RevId: 207714420
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/stream.cc13
-rw-r--r--tensorflow/stream_executor/stream.h5
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.cc12
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.h2
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc5
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h5
6 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index a42a469df5..9efd34de24 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5294,6 +5294,19 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
return *this;
}
+Stream &Stream::ThenDoHostCallbackWithStatus(
+ std::function<port::Status()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ if (ok()) {
+ CheckError(parent_->HostCallback(this, std::move(callback)));
+ } else {
+ LOG(WARNING) << "stream " << DebugStreamPointers()
+ << " was in error state before adding host callback";
+ }
+ return *this;
+}
+
Stream &Stream::ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,
DeviceMemory<std::complex<float>> *output) {
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 4d41409fef..e1629b5b30 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -2045,6 +2045,11 @@ class Stream {
// negative effects on performance.
Stream &ThenDoHostCallback(std::function<void()> callback);
+ // Entrains onto the stream a callback to the host (from the device).
+ // Behaves as ThenDoHostCallback above, but returns a Status instead of void.
+ // This overload should be preferred if the callback could fail.
+ Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
+
// Returns the StreamExecutor (parent object) associated with this stream.
StreamExecutor *parent() const {
CHECK(parent_ != nullptr);
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
index 8297228e6f..7df6a361c6 100644
--- a/tensorflow/stream_executor/stream_executor_internal.cc
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() {
StreamExecutorFactory MakeHostExecutorImplementation;
+// TODO(b/112125301): Consolodate this down to one implementation of
+// HostCallback, taking a callback that returns a Status.
+bool StreamExecutorInterface::HostCallback(
+ Stream* stream, std::function<port::Status()> callback) {
+ return HostCallback(stream, [callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "HostCallback failed: " << s;
+ }
+ });
+}
+
} // namespace internal
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index f34b1fc083..92e5376835 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -239,6 +239,8 @@ class StreamExecutorInterface {
const DeviceMemoryBase &host_src,
uint64 size) = 0;
virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
+ virtual bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback);
virtual port::Status AllocateEvent(Event *event) = 0;
virtual port::Status DeallocateEvent(Event *event) = 0;
virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2e0137a485..9515d8e62a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -699,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream,
return implementation_->HostCallback(stream, std::move(callback));
}
+bool StreamExecutor::HostCallback(Stream *stream,
+ std::function<port::Status()> callback) {
+ return implementation_->HostCallback(stream, std::move(callback));
+}
+
port::Status StreamExecutor::AllocateEvent(Event *event) {
return implementation_->AllocateEvent(event);
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 47b3a2b030..437f298616 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -549,6 +549,11 @@ class StreamExecutor {
// See Stream::ThenDoHostCallback for full details.
bool HostCallback(Stream *stream, std::function<void()> callback);
+ // Entrains on a stream a user-specified function to be run on the host.
+ // See Stream::ThenDoHostCallback for full details.
+ // This is the preferred form for a callback that may return an error.
+ bool HostCallback(Stream *stream, std::function<port::Status()> callback);
+
// Performs platform-specific allocation and initialization of an event.
port::Status AllocateEvent(Event *event);