Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.h')
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.h  100
1 file changed, 100 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
new file mode 100644
index 0000000000..1dfec2ebc5
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -0,0 +1,100 @@
+// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
+// capabilities, and is only included into CUDA implementation code -- it will
+// not introduce CUDA headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+typedef struct cublasContext *cublasHandle_t;
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+namespace cuda {
+
+// Opaque and unique identifier for the cuBLAS plugin.
+extern const PluginId kCuBlasPlugin;
+
+class CUDAExecutor;
+
+// BLAS plugin for CUDA platform via cuBLAS library.
+//
+// This satisfies the platform-agnostic BlasSupport interface.
+//
+// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
+// context (and, as a result, the device) that the parent CUDAExecutor is tied
+// to. This simply happens as an artifact of creating the cuBLAS handle when a
+// CUDA context is active.
+//
+// Thread-safe post-initialization.
+class CUDABlas : public blas::BlasSupport {
+ public:
+ explicit CUDABlas(CUDAExecutor *parent);
+
+ // Allocates a cuBLAS handle.
+ bool Init();
+
+ // Releases the cuBLAS handle, if present.
+ ~CUDABlas() override;
+
+ TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
+
+ private:
+ // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
+ //
+ // cuBLAS is stateful, and can only be associated with one stream (in order
+ // to enqueue dispatch) at a given time. As a result, this generally must be
+ // invoked before calling into cuBLAS.
+ bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // A helper function that calls the real cuBLAS function together with error
+ // handling.
+ //
+ // cublas_func: cuBLAS function pointer.
+ // cublas_name: cuBLAS function name.
+ // stream: Stream to enqueue the BLAS operation onto.
+ // pointer_mode_host: indicates whether the pointer to a scalar value (e.g.
+ //   alpha or beta) resides on the host (true) or on the device (false).
+ // args: Arguments of cuBLAS function.
+ template <typename FuncT, typename... Args>
+ bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
+ Args... args);
+
+ // A helper function to implement DoBlasGemmBatched interfaces for generic
+ // types.
+ template <typename T, typename FuncT>
+ port::Status DoBlasGemmBatchedInternal(
+ FuncT cublas_func, Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+ const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
+ int batch_count);
+
+ // Mutex that guards the cuBLAS handle for this device.
+ mutex mu_;
+
+ // CUDAExecutor which instantiated this CUDABlas.
+ // Immutable post-initialization.
+ CUDAExecutor *parent_;
+
+ // cuBLAS library handle for this device.
+ cublasHandle_t blas_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
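
For readers unfamiliar with the pattern this header describes, the standalone sketch below shows roughly what a mutex-guarded cuBLAS handle looks like in practice: the handle is bound to a CUDA stream and to a scalar pointer mode before each call, then the cuBLAS function is invoked and its status checked. This is not the CUDABlas implementation; the BlasWrapper class, its Call method, and the surrounding main are illustrative names, written directly against the raw CUDA runtime and cuBLAS APIs rather than StreamExecutor types.

// minimal_cublas_wrapper.cc -- a standalone sketch (not the actual CUDABlas
// implementation) of the pattern this header describes: a mutex-guarded
// cuBLAS handle bound to a stream and to a scalar pointer mode before every
// call. Build with: nvcc -o wrapper minimal_cublas_wrapper.cc -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <mutex>
#include <vector>

// Rough analogue of the members above: handle_ plays the role of blas_, and
// Call() the role of DoBlasInternal.
class BlasWrapper {
 public:
  bool Init() { return cublasCreate(&handle_) == CUBLAS_STATUS_SUCCESS; }
  ~BlasWrapper() {
    if (handle_ != nullptr) cublasDestroy(handle_);
  }

  // Binds the handle to `stream`, selects where the alpha/beta scalars live,
  // invokes the cuBLAS function, and reports success or failure.
  template <typename FuncT, typename... Args>
  bool Call(FuncT cublas_func, cudaStream_t stream, bool pointer_mode_host,
            Args... args) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cublasSetStream(handle_, stream) != CUBLAS_STATUS_SUCCESS) return false;
    cublasPointerMode_t mode = pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
                                                 : CUBLAS_POINTER_MODE_DEVICE;
    if (cublasSetPointerMode(handle_, mode) != CUBLAS_STATUS_SUCCESS) {
      return false;
    }
    return cublas_func(handle_, args...) == CUBLAS_STATUS_SUCCESS;
  }

 private:
  std::mutex mu_;                    // guards handle_
  cublasHandle_t handle_ = nullptr;  // cuBLAS handle for this device
};

int main() {
  BlasWrapper blas;
  if (!blas.Init()) return 1;

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // C = alpha * A * B + beta * C with 2x2 column-major matrices.
  const int n = 2;
  std::vector<float> host_a = {1, 2, 3, 4}, host_b = {5, 6, 7, 8}, host_c(4, 0);
  float *a, *b, *c;
  cudaMalloc(&a, 4 * sizeof(float));
  cudaMalloc(&b, 4 * sizeof(float));
  cudaMalloc(&c, 4 * sizeof(float));
  cudaMemcpy(a, host_a.data(), 4 * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(b, host_b.data(), 4 * sizeof(float), cudaMemcpyHostToDevice);

  const float alpha = 1.0f, beta = 0.0f;  // host scalars -> pointer_mode_host
  bool ok = blas.Call(cublasSgemm, stream, /*pointer_mode_host=*/true,
                      CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, a, n, b, n,
                      &beta, c, n);

  cudaStreamSynchronize(stream);
  cudaMemcpy(host_c.data(), c, 4 * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("ok=%d, c[0]=%g\n", ok ? 1 : 0, host_c[0]);

  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
  cudaStreamDestroy(stream);
  return ok ? 0 : 1;
}

The lock around the whole call is what makes "thread-safe post-initialization" hold: the stream binding, the pointer mode, and the dispatch are a single critical section, so concurrent callers cannot interleave and enqueue onto the wrong stream.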
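
The batched helper above (DoBlasGemmBatchedInternal) takes a slice of DeviceMemory<T> pointers per operand. The cuBLAS batched entry points it presumably forwards to, such as cublasSgemmBatched, expect arrays of per-matrix device pointers that are themselves resident in device memory. The second standalone sketch below, again with made-up names and sizes and not the code behind this header, shows that calling convention directly against cuBLAS.

// batched_gemm_sketch.cc -- a standalone sketch of the batched calling
// convention: one call launches `batch` independent GEMMs, and the per-matrix
// pointer arrays must live in device memory.
// Build with: nvcc -o batched batched_gemm_sketch.cc -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

int main() {
  cublasHandle_t handle;
  if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) return 1;

  const int n = 2, batch = 3, elems = n * n;

  // Allocate `batch` independent A, B and C matrices on the device.
  std::vector<float*> a(batch), b(batch), c(batch);
  std::vector<float> ones(elems, 1.0f);
  for (int i = 0; i < batch; ++i) {
    cudaMalloc(&a[i], elems * sizeof(float));
    cudaMalloc(&b[i], elems * sizeof(float));
    cudaMalloc(&c[i], elems * sizeof(float));
    cudaMemcpy(a[i], ones.data(), elems * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(b[i], ones.data(), elems * sizeof(float), cudaMemcpyHostToDevice);
  }

  // The batched API reads the pointer lists from device memory, so copy the
  // host-side pointer arrays over before the call.
  float **dev_a, **dev_b, **dev_c;
  cudaMalloc(&dev_a, batch * sizeof(float*));
  cudaMalloc(&dev_b, batch * sizeof(float*));
  cudaMalloc(&dev_c, batch * sizeof(float*));
  cudaMemcpy(dev_a, a.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(dev_b, b.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(dev_c, c.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);

  // One call, `batch` independent n x n GEMMs.
  const float alpha = 1.0f, beta = 0.0f;
  cublasStatus_t status = cublasSgemmBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha,
      const_cast<const float**>(dev_a), n, const_cast<const float**>(dev_b), n,
      &beta, dev_c, n, batch);
  cudaDeviceSynchronize();
  std::printf("batched gemm status=%d\n", static_cast<int>(status));

  for (int i = 0; i < batch; ++i) {
    cudaFree(a[i]);
    cudaFree(b[i]);
    cudaFree(c[i]);
  }
  cudaFree(dev_a);
  cudaFree(dev_b);
  cudaFree(dev_c);
  cublasDestroy(handle);
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}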