From 9cdcb0397ceed7a43c8e85865a479440b2350102 Mon Sep 17 00:00:00 2001 From: Todd Wang Date: Fri, 3 Aug 2018 15:21:58 -0700 Subject: Drop failed sub-streams during both Get and Return. The old code ensured that failed sub-streams would not be re-used, but had two flaws: 1) It only checked for failed sub-streams during Return. 2) It didn't actually remove the failed sub-streams from our state. The new code fixes these two flaws, and adds an extra test that explains why (1) is insufficient. PiperOrigin-RevId: 207333296 --- tensorflow/stream_executor/stream.cc | 180 ++++++++++++++++++++---------- tensorflow/stream_executor/stream.h | 7 ++ tensorflow/stream_executor/stream_test.cc | 90 ++++++++++++--- 3 files changed, 207 insertions(+), 70 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 6439e3992d..a42a469df5 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto ¶m : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. + for (int64 index = 0; index < sub_streams_.size();) { + std::pair, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. + Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - // Streams have a monotonic state machine; if a stream - // encounters an error, it will remain in an error state - // forever. Only allow re-use of ok streams. - // - // TODO(toddw): Improve this mechanism, if necessary, to drop - // failed streams completely. - const bool ready_to_reuse = sub_stream->ok(); - stream.second = ready_to_reuse; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4802,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast(seed) << "; bytes: " << seed_bytes; } @@ -4820,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4836,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4852,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4867,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4883,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4899,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4914,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4927,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4940,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4952,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4965,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5234,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; @@ -5250,8 +5304,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5267,8 +5322,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5283,8 +5339,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5299,8 +5356,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5316,8 +5374,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5333,8 +5392,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5361,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } @@ -5372,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() { return error; } +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. + return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); +} + } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 62d0a2062d..4d41409fef 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -122,10 +122,14 @@ class Stream { // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused // later. Sub-streams that are !ok() will not be reused. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_); // Allocate temporary memories. The stream will deallocate them when blocked @@ -2051,6 +2055,9 @@ class Stream { // with this stream. internal::TemporaryMemoryManager *temporary_memory_manager(); + // Returns a debugging string "[stream=0x...,impl=0x...]". + string DebugStreamPointers() const; + private: friend class host::HostBlas; // for parent_. friend class host::HostFft; // for parent_. diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc index 47dd675834..cfc051fd09 100644 --- a/tensorflow/stream_executor/stream_test.cc +++ b/tensorflow/stream_executor/stream_test.cc @@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) { EXPECT_NE(sub_stream3, sub_stream4); } -TEST_F(StreamTest, FailedSubStreamNotReused) { +TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { std::unique_ptr executor = NewStreamExecutor(); Stream stream(executor.get()); stream.Init(); EXPECT_TRUE(stream.ok()); - // Get a sub-stream. + // Get sub_stream1. Stream* sub_stream1 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream1->ok()); - // Force an error on the stream; here we call a method that requires - // DNN support, which we know the Host platform doesn't support. + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. sub_stream1->ThenDepthConcatenate({}, {}, nullptr); EXPECT_FALSE(sub_stream1->ok()); @@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) { Stream* sub_stream2 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream2->ok()); - // The underlying streams should be different. They would have been - // the same, but since we forced an error on sub_stream1, it will - // not be re-used. Sadly we can't just check: + // The underlying sub_streams should be different. They would have been the + // same, but since we forced an error on sub_stream1, it will not be + // re-used. Sadly we can't just check: // EXPECT_NE(sub_stream1, sub_stream2); // - // The above should hold logically, but it may fail if the new - // stream instance allocated for sub_stream2 happens to reside in - // the same memory address as sub_stream1. + // The above should hold logically, but it may fail if the new Stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. // // The check that sub_stream2->ok() serves as a good-enough check. - // Return sub_stream2 and get sub_stream3. The previous error on - // sub_stream1 has no effect on these streams, and they are the - // same. + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream2, sub_stream3); +} + +TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + // + // It is a bit weird to use sub_stream1 after it has already been returned. By + // doing this, we're simulating an asynchronous error that occurs during + // execution of the sub_stream, that occurs after the sub_stream is returned. + // + // E.g. the following is a common pattern of usage, where the execution of the + // operations enqueued onto the sub streams may occur after the streams have + // already been returned. + // + // void EnqueueOnSubStreams(Stream* stream) { + // Stream* sub_stream1 = stream.GetOrCreateSubStream(); + // Stream* sub_stream2 = stream.GetOrCreateSubStream(); + // // ... enqueue some operations on the sub streams ... + // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); + // stream.ReturnSubStream(sub_stream1); + // stream.ReturnSubStream(sub_stream2); + // } + // + // Stream* main_stream = ...; + // EnqueueOnSubStreams(main_stream); + // main_stream.BlockHostUntilDone(); + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone; + // GetOrCreateSubStream can still return a sub-stream that has not encountered + // an error yet, but will encounter one in the future, based on previously + // enqueued operations. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Get and return sub_stream2. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying streams should be different. They would have been the same, + // but since we forced an error on sub_stream1, it will not be re-used. Sadly + // we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. stream.ReturnSubStream(sub_stream2); Stream* sub_stream3 = stream.GetOrCreateSubStream(); EXPECT_TRUE(sub_stream3->ok()); -- cgit v1.2.3