diff options
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r-- | tensorflow/stream_executor/stream.h | 67 |
1 files changed, 59 insertions, 8 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index e8885e1eb6..e1629b5b30 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -122,10 +122,14 @@ class Stream { // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused - // later. + // later. Sub-streams that are !ok() will not be reused. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_); // Allocate temporary memories. The stream will deallocate them when blocked @@ -629,19 +633,22 @@ class Stream { const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<double> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<double> *output_data); + DeviceMemory<double> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data); + DeviceMemory<float> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<Eigen::half> &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<Eigen::half> *output_data); + DeviceMemory<Eigen::half> *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -649,7 +656,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<double> &output_data, const DeviceMemory<double> &input_diff_data, - DeviceMemory<double> *output_diff_data); + DeviceMemory<double> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -657,7 +665,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<float> &output_data, const DeviceMemory<float> &input_diff_data, - DeviceMemory<float> *output_diff_data); + DeviceMemory<float> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -665,7 +674,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory<Eigen::half> &output_data, const DeviceMemory<Eigen::half> &input_diff_data, - DeviceMemory<Eigen::half> *output_diff_data); + DeviceMemory<Eigen::half> *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor, const DeviceMemory<float> &input_data, @@ -684,7 +694,8 @@ class Stream { const DeviceMemory<float> &raw_data, const DeviceMemory<float> &normalized_data, const DeviceMemory<float> &normalized_variable_gradient, - DeviceMemory<float> *raw_variable_gradient); + DeviceMemory<float> *raw_variable_gradient, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenActivate(dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &dimensions, @@ -1550,6 +1561,38 @@ class Stream { std::complex<double> beta, const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, + int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count); // See BlasSupport::DoBlasHemm. Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, @@ -2002,6 +2045,11 @@ class Stream { // negative effects on performance. Stream &ThenDoHostCallback(std::function<void()> callback); + // Entrains onto the stream a callback to the host (from the device). + // Behaves as ThenDoHostCallback above, but returns a Status instead of void. + // This overload should be preferred if the callback could fail. + Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback); + // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); @@ -2012,6 +2060,9 @@ class Stream { // with this stream. internal::TemporaryMemoryManager *temporary_memory_manager(); + // Returns a debugging string "[stream=0x...,impl=0x...]". + string DebugStreamPointers() const; + private: friend class host::HostBlas; // for parent_. friend class host::HostFft; // for parent_. |