From 1e55f802fd4bc30a3e68726e2914dba0c08ffbd8 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Fri, 3 Aug 2018 10:29:38 -0700
Subject: [XLA:GPU] Add a fast version of gemmStridedBatched for cuda 9.1

It's unfortunate that this was only added in 9.1, but I haven't found a
good way of emulating the behavior on 9.0 without falling back to
non-batched gemms.

PiperOrigin-RevId: 207286575
---
 tensorflow/stream_executor/cuda/cuda_blas.cc | 34 +++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 8 deletions(-)

(limited to 'tensorflow/stream_executor')

diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index e5b5f7fc94..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -292,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
 
 #if CUDA_VERSION >= 9010
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
 #endif
 
 }  // namespace wrap
@@ -2636,16 +2637,33 @@ bool CUDABlas::DoBlasGemmStridedBatched(
   bool use_tensor_ops = false;
 #if CUDA_VERSION >= 9000
   int cc_major, cc_minor;
-  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
-                                                                   &cc_minor);
-
-  // GPUs < sm_70 don't support tensor ops.
-  if (cc_major >= 7 && TensorOpMathEnabled()) {
-    use_tensor_ops = true;
+  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+          &cc_major, &cc_minor)) {
+    // GPUs < sm_70 don't support tensor ops.
+    if (cc_major >= 7 && TensorOpMathEnabled()) {
+      use_tensor_ops = true;
+    }
+#if CUDA_VERSION >= 9010
+    if (cc_major >= 5) {
+      cublasGemmAlgo_t algo =
+          (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+      bool ok = DoBlasInternalImpl(
+          wrap::cublasGemmStridedBatchedEx, stream,
+          true /* = pointer_mode_host */, true /* = err_on_failure */,
+          use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+          m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+          CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+          CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+      if (ok) {
+        return true;
+      }
+      LOG(ERROR) << "failed BLAS call, see log for details";
+      return false;
+    }
+#endif
   }
 #endif
-  // We'd need cublasgemmStridedBatchedEx for this, which isn't available before
-  // CUDA 9.1. Fall back to a loop.
+  // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
   for (int batch = 0; batch < batch_count; ++batch) {
     const auto *a_matrix =
         reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
--
cgit v1.2.3
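
For context, here is a minimal standalone sketch of the cublasGemmStridedBatchedEx
call that the patch routes through DoBlasInternalImpl: half-precision inputs,
fp32 accumulation, host-side scalars, and the default algorithm, as in the
patch. This is not TensorFlow code; the matrix shapes, leading dimensions, and
the choice to skip tensor-op algorithm selection are illustrative assumptions.
Build with something like: nvcc gemm_strided_batched.cu -lcublas

// Sketch only: exercises cublasGemmStridedBatchedEx (CUDA >= 9.1) with the
// same type configuration the patch uses (CUDA_R_16F data, CUDA_R_32F
// compute). Shapes and strides are illustrative; error handling is minimal.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int m = 64, n = 64, k = 64, batch_count = 16;
  // Pack the batch contiguously: each matrix starts one full matrix after
  // the previous one.
  const long long stride_a = 1LL * m * k;
  const long long stride_b = 1LL * k * n;
  const long long stride_c = 1LL * m * n;

  __half *a = nullptr, *b = nullptr, *c = nullptr;
  cudaMalloc(&a, batch_count * stride_a * sizeof(__half));
  cudaMalloc(&b, batch_count * stride_b * sizeof(__half));
  cudaMalloc(&c, batch_count * stride_c * sizeof(__half));

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Host-side scalars in fp32, matching the CUDA_R_32F compute type the
  // patch requests for half-precision inputs (default pointer mode is host,
  // which corresponds to pointer_mode_host in the patch).
  float alpha = 1.0f, beta = 0.0f;

#if CUDA_VERSION >= 9010
  cublasStatus_t status = cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      a, CUDA_R_16F, /*lda=*/m, stride_a,
      b, CUDA_R_16F, /*ldb=*/k, stride_b, &beta,
      c, CUDA_R_16F, /*ldc=*/m, stride_c,
      batch_count, /*computeType=*/CUDA_R_32F, CUBLAS_GEMM_DFALT);
  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "cublasGemmStridedBatchedEx failed: %d\n", status);
  }
#endif

  cublasDestroy(handle);
  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
  return 0;
}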
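
And the shape of the fallback that remains for CUDA < 9.1 (or pre-sm_50 GPUs):
one plain GEMM per batch entry, stepping each base pointer by its stride. This
sketch uses cublasGemmEx, available since CUDA 8.0; the helper name and the
fail-fast policy are assumptions for illustration, not the patch's exact loop,
which goes through TensorFlow's internal wrappers instead.

// Hypothetical helper sketching the pre-9.1 fallback: emulate a strided
// batched GEMM with batch_count separate cublasGemmEx calls, offsetting
// the A/B/C base pointers by the batch strides.
// (Same includes as the sketch above: cublas_v2.h, cuda_fp16.h.)
static bool HalfGemmStridedBatchedFallback(
    cublasHandle_t handle, int m, int n, int k, float alpha,
    const __half *a, int lda, long long stride_a,
    const __half *b, int ldb, long long stride_b, float beta,
    __half *c, int ldc, long long stride_c, int batch_count) {
  for (int batch = 0; batch < batch_count; ++batch) {
    cublasStatus_t status = cublasGemmEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
        a + batch * stride_a, CUDA_R_16F, lda,
        b + batch * stride_b, CUDA_R_16F, ldb, &beta,
        c + batch * stride_c, CUDA_R_16F, ldc,
        CUDA_R_32F, CUBLAS_GEMM_DFALT);
    if (status != CUBLAS_STATUS_SUCCESS) return false;  // fail fast
  }
  return true;
}

The point of the 9.1 API is to collapse these batch_count separate calls into
a single launch, which is the win the patch's fast path is after.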