diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 297 |
1 files changed, 246 insertions, 51 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 2c495c99e1..9efd34de24 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto &param : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -267,13 +268,13 @@ Stream::Stream(StreamExecutor *parent, Stream::~Stream() { VLOG_CALL(); - temporary_memory_manager_.ForceDeallocateAll(); // Ensure the stream is completed. auto status = BlockHostUntilDone(); if (!status.ok()) { LOG(WARNING) << "Error blocking host until done in stream destructor: " << status; } + temporary_memory_manager_.ForceDeallocateAll(); if (allocated_) { parent_->DeallocateStream(this); @@ -1922,30 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. 
+ for (int64 index = 0; index < sub_streams_.size();) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. + Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - stream.second = true; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. 
Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1954,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1965,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1978,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -1995,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. 
Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4678,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( scratch_allocator); } +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const DeviceMemory<Eigen::half> &, int, int64, + const DeviceMemory<Eigen::half> &, int, int64, float, + DeviceMemory<Eigen::half> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, + const 
DeviceMemory<float> &, int, int64, + const DeviceMemory<float> &, int, int64, float, + DeviceMemory<float> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, + const DeviceMemory<double> &, int, int64, + const DeviceMemory<double> &, int, int64, double, + DeviceMemory<double> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, 
blas::Transpose, uint64, uint64, uint64, + std::complex<float>, const DeviceMemory<std::complex<float>> &, + int, int64, const DeviceMemory<std::complex<float>> &, int, + int64, std::complex<float>, DeviceMemory<std::complex<float>> *, + int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, + std::complex<double>, const DeviceMemory<std::complex<double>> &, + int, int64, const DeviceMemory<std::complex<double>> &, int, + int64, std::complex<double>, + DeviceMemory<std::complex<double>> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); @@ -4686,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } 
} else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast<const void *>(seed) << "; bytes: " << seed_bytes; } @@ -4704,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4720,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4736,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4751,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4767,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) 
<< "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4783,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4798,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4811,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4824,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4836,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4849,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase 
*location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5118,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; } +Stream &Stream::ThenDoHostCallbackWithStatus( + std::function<port::Status()> callback) { + VLOG_CALL(PARAM(callback)); + + if (ok()) { + CheckError(parent_->HostCallback(this, std::move(callback))); + } else { + LOG(WARNING) << "stream " << DebugStreamPointers() + << " was in error state before adding host callback"; + } + return *this; +} + Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<std::complex<float>> &input, DeviceMemory<std::complex<float>> *output) { @@ -5134,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5151,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5167,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, CheckError(fft->DoFft(this, plan, input, output)); } 
else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5183,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5200,8 +5387,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5217,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5245,7 +5434,7 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } @@ -5256,4 +5445,10 @@ port::Status Stream::BlockHostUntilDone() { return error; } +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. 
+ return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); +} + } // namespace stream_executor |