aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-07 07:50:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 07:54:17 -0700
commit335336aa2cdf853d380c3e22ab6694ff78cb487a (patch)
tree17a9686337bd3cd7444f66fb8e796cb01e04bfc6 /tensorflow/stream_executor/stream.cc
parenta1a370cb8b8ef43996a275b64ada81f9cb32e743 (diff)
Implement DoHostCallbackWithStatus to allow callbacks to return a status
PiperOrigin-RevId: 207714420
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index a42a469df5..9efd34de24 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5294,6 +5294,19 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
return *this;
}
+Stream &Stream::ThenDoHostCallbackWithStatus(
+ std::function<port::Status()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ if (ok()) {
+ CheckError(parent_->HostCallback(this, std::move(callback)));
+ } else {
+ LOG(WARNING) << "stream " << DebugStreamPointers()
+ << " was in error state before adding host callback";
+ }
+ return *this;
+}
+
Stream &Stream::ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,
DeviceMemory<std::complex<float>> *output) {