diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 180 |
1 files changed, 123 insertions, 57 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 6439e3992d..a42a469df5 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto ¶m : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. + for (int64 index = 0; index < sub_streams_.size();) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. + Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - // Streams have a monotonic state machine; if a stream - // encounters an error, it will remain in an error state - // forever. Only allow re-use of ok streams. - // - // TODO(toddw): Improve this mechanism, if necessary, to drop - // failed streams completely. - const bool ready_to_reuse = sub_stream->ok(); - stream.second = ready_to_reuse; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4802,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast<const void *>(seed) << "; bytes: " << seed_bytes; } @@ -4820,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4836,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4852,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4867,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4883,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4899,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4914,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4927,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4940,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4952,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4965,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5234,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; @@ -5250,8 +5304,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5267,8 +5322,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5283,8 +5339,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5299,8 +5356,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5316,8 +5374,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5333,8 +5392,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5361,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } @@ -5372,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() { return error; } +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. + return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); +} + } // namespace stream_executor |