Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 102 |
1 file changed, 101 insertions(+), 1 deletion(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 6988389f29..e5b5f7fc94 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
 
 #if CUDA_VERSION >= 8000
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
 #endif
 
 #if CUDA_VERSION >= 9000
@@ -1865,7 +1869,7 @@ bool CUDABlas::DoBlasGemm(
     stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                      &cc_minor);
 
-    // GPUs < sm_70 don't support Volta hardware.
+    // GPUs < sm_70 don't support tensor ops.
     if (cc_major >= 7 && TensorOpMathEnabled()) {
       use_tensor_ops = true;
     }
@@ -2623,6 +2627,102 @@ bool CUDABlas::DoBlasGemmBatched(
   return status.ok();
 }
 
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+  int cc_major, cc_minor;
+  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                                   &cc_minor);
+
+  // GPUs < sm_70 don't support tensor ops.
+  if (cc_major >= 7 && TensorOpMathEnabled()) {
+    use_tensor_ops = true;
+  }
+#endif
+  // We'd need cublasGemmStridedBatchedEx for this, which isn't available
+  // before CUDA 9.1. Fall back to a loop.
+  for (int batch = 0; batch < batch_count; ++batch) {
+    const auto *a_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+    const auto *b_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+    auto *c_matrix =
+        reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+    bool ok = DoBlasInternalImpl(
+        wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+        CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+        lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+        SE_CUDA_DATA_HALF, ldc);
+    if (!ok) {
+      LOG(ERROR) << "failed BLAS call, see log for details";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                           blas::UpperLower uplo, uint64 m, uint64 n,
                           std::complex<float> alpha,
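For reference, the cuBLAS entry points wrapped above compute C_i = alpha * op(A_i) * op(B_i) + beta * C_i for each i in [0, batch_count), where matrix i lives at the base pointer plus i times the corresponding stride, so one call covers the whole batch. Below is a minimal standalone sketch of a direct cublasSgemmStridedBatched call, mirroring what the new float overload boils down to. This is plain CUDA host code, not StreamExecutor; names like host_a are illustrative, and error handling is elided.

// Multiply batch_count pairs of column-major m x k and k x n matrices in
// one cublasSgemmStridedBatched call.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int m = 2, n = 2, k = 2, batch_count = 3;
  // Matrices are packed back to back, so the stride is one matrix's size.
  const long long stride_a = m * k, stride_b = k * n, stride_c = m * n;

  std::vector<float> host_a(stride_a * batch_count, 1.0f);
  std::vector<float> host_b(stride_b * batch_count, 2.0f);
  std::vector<float> host_c(stride_c * batch_count, 0.0f);

  float *a, *b, *c;
  cudaMalloc(&a, host_a.size() * sizeof(float));
  cudaMalloc(&b, host_b.size() * sizeof(float));
  cudaMalloc(&c, host_c.size() * sizeof(float));
  cudaMemcpy(a, host_a.data(), host_a.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(b, host_b.data(), host_b.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;
  // C_i = alpha * A_i * B_i + beta * C_i, with A_i at a + i * stride_a, etc.
  cublasStatus_t status = cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, a, m /* lda */,
      stride_a, b, k /* ldb */, stride_b, &beta, c, m /* ldc */, stride_c,
      batch_count);

  cudaMemcpy(host_c.data(), c, host_c.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("status=%d c[0]=%f\n", status, host_c[0]);  // expect c[0] = 4.0

  cublasDestroy(handle);
  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
  return 0;
}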
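The Eigen::half overload in the patch cannot take the same one-call route: it accepts float alpha/beta (fp32 accumulation over fp16 data), a combination the strided-batched entry points of CUDA 8/9.0 don't expose, which is why it loops over cublasSgemmEx per batch. Once cublasGemmStridedBatchedEx is available (CUDA >= 9.1), that loop could collapse into a single call, as in the hedged sketch below. HalfGemmStridedBatched is a hypothetical helper, not StreamExecutor API; transposes are fixed to CUBLAS_OP_N for brevity, and the signature shown is the CUDA 9.x/10.x one, where the compute type is still passed as a cudaDataType.

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>

// Hypothetical one-call replacement for the per-batch cublasSgemmEx loop:
// fp16 inputs/outputs with fp32 accumulation, matching the patch's semantics.
cublasStatus_t HalfGemmStridedBatched(
    cublasHandle_t handle, int m, int n, int k, float alpha, const __half *a,
    int lda, long long stride_a, const __half *b, int ldb, long long stride_b,
    float beta, __half *c, int ldc, long long stride_c, int batch_count,
    bool use_tensor_ops) {
#if CUDA_VERSION >= 9010
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, a, CUDA_R_16F, lda,
      stride_a, b, CUDA_R_16F, ldb, stride_b, &beta, c, CUDA_R_16F, ldc,
      stride_c, batch_count, CUDA_R_32F /* compute in fp32 */,
      use_tensor_ops ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT);
#else
  // Pre-9.1 toolkits would still need the per-batch fallback loop.
  return CUBLAS_STATUS_NOT_SUPPORTED;
#endif
}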