diff options
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 34 |
1 file changed, 26 insertions, 8 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index e5b5f7fc94..ab7091b3f5 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -292,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #if CUDA_VERSION >= 9010 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx) #endif } // namespace wrap @@ -2636,16 +2637,33 @@ bool CUDABlas::DoBlasGemmStridedBatched( bool use_tensor_ops = false; #if CUDA_VERSION >= 9000 int cc_major, cc_minor; - stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - - // GPUs < sm_70 don't support tensor ops. - if (cc_major >= 7 && TensorOpMathEnabled()) { - use_tensor_ops = true; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor)) { + // GPUs < sm_70 don't support tensor ops. + if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#if CUDA_VERSION >= 9010 + if (cc_major >= 5) { + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + bool ok = DoBlasInternalImpl( + wrap::cublasGemmStridedBatchedEx, stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a, + CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); + if (ok) { + return true; + } + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } +#endif } #endif - // We'd need cublasgemmStridedBatchedEx for this, which isn't available before - // CUDA 9.1. Fall back to a loop. + // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. 
for (int batch = 0; batch < batch_count; ++batch) { const auto *a_matrix = reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a); |