// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library // capabilities, and is only included into CUDA implementation code -- it will // not introduce cuda headers into other code. #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/plugin_registry.h" typedef struct cublasContext *cublasHandle_t; namespace perftools { namespace gputools { class Stream; namespace cuda { // Opaque and unique identifier for the cuBLAS plugin. extern const PluginId kCuBlasPlugin; class CUDAExecutor; // BLAS plugin for CUDA platform via cuBLAS library. // // This satisfies the platform-agnostic BlasSupport interface. // // Note that the cuBLAS handle that this encapsulates is implicitly tied to the // context (and, as a result, the device) that the parent CUDAExecutor is tied // to. This simply happens as an artifact of creating the cuBLAS handle when a // CUDA context is active. // // Thread-safe post-initialization. class CUDABlas : public blas::BlasSupport { public: explicit CUDABlas(CUDAExecutor *parent); // Allocates a cuBLAS handle. bool Init(); // Releases the cuBLAS handle, if present. ~CUDABlas() override; TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES private: // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream. // // cuBLAS is stateful, and only be associated with one stream (in order to // enqueue dispatch) at a given time. As a result, this generally must be // invoked before calling into cuBLAS. bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); // A helper function that calls the real cuBLAS function together with error // handling. // // cublas_func: cuBLAS function pointer. // cublas_name: cuBLAS function name. // stream: Stream to enqueue the BLAS operation onto. // pointer_mode_host: Indicate if the pointer to a scalar value is from host // (true) or device (false). // args: Arguments of cuBLAS function. template bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, Args... args); // A helper function to implement DoBlasGemmBatched interfaces for generic // types. template port::Status DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, const port::ArraySlice *> &a_array, int lda, const port::ArraySlice *> &b_array, int ldb, T beta, const port::ArraySlice *> &c_array, int ldc, int batch_count); // mutex that guards the cuBLAS handle for this device. mutex mu_; // CUDAExecutor which instantiated this CUDABlas. // Immutable post-initialization. CUDAExecutor *parent_; // cuBLAS library handle on the device. cublasHandle_t blas_ GUARDED_BY(mu_); SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); }; } // namespace cuda } // namespace gputools } // namespace perftools #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_