diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.h')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.h | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h new file mode 100644 index 0000000000..1dfec2ebc5 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -0,0 +1,100 @@ +// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library +// capabilities, and is only included into CUDA implementation code -- it will +// not introduce cuda headers into other code. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ + +#include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/lib/stringpiece.h" +#include "tensorflow/stream_executor/platform/mutex.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" +#include "tensorflow/stream_executor/plugin_registry.h" + +typedef struct cublasContext *cublasHandle_t; + +namespace perftools { +namespace gputools { + +class Stream; + +namespace cuda { + +// Opaque and unique identifier for the cuBLAS plugin. +extern const PluginId kCuBlasPlugin; + +class CUDAExecutor; + +// BLAS plugin for CUDA platform via cuBLAS library. +// +// This satisfies the platform-agnostic BlasSupport interface. +// +// Note that the cuBLAS handle that this encapsulates is implicitly tied to the +// context (and, as a result, the device) that the parent CUDAExecutor is tied +// to. This simply happens as an artifact of creating the cuBLAS handle when a +// CUDA context is active. +// +// Thread-safe post-initialization. +class CUDABlas : public blas::BlasSupport { + public: + explicit CUDABlas(CUDAExecutor *parent); + + // Allocates a cuBLAS handle. + bool Init(); + + // Releases the cuBLAS handle, if present. + ~CUDABlas() override; + + TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES + + private: + // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream. + // + // cuBLAS is stateful, and only be associated with one stream (in order to + // enqueue dispatch) at a given time. As a result, this generally must be + // invoked before calling into cuBLAS. + bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // A helper function that calls the real cuBLAS function together with error + // handling. + // + // cublas_func: cuBLAS function pointer. + // cublas_name: cuBLAS function name. + // stream: Stream to enqueue the BLAS operation onto. + // pointer_mode_host: Indicate if the pointer to a scalar value is from host + // (true) or device (false). + // args: Arguments of cuBLAS function. + template <typename FuncT, typename... Args> + bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, + Args... args); + + // A helper function to implement DoBlasGemmBatched interfaces for generic + // types. + template <typename T, typename FuncT> + port::Status DoBlasGemmBatchedInternal( + FuncT cublas_func, Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda, + const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta, + const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, + int batch_count); + + // mutex that guards the cuBLAS handle for this device. + mutex mu_; + + // CUDAExecutor which instantiated this CUDABlas. + // Immutable post-initialization. + CUDAExecutor *parent_; + + // cuBLAS library handle on the device. + cublasHandle_t blas_ GUARDED_BY(mu_); + + SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); +}; + +} // namespace cuda +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ |