diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-08-03 10:29:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 10:33:04 -0700 |
commit | 1e55f802fd4bc30a3e68726e2914dba0c08ffbd8 (patch) | |
tree | e7d7234c073301ad56cdaf098e93502fc0ab2e72 /tensorflow/stream_executor | |
parent | e36c16c6720dd64ae5d8a1f8555102a1323af9ae (diff) |
[XLA:GPU] Add a fast version of gemmStridedBatched for cuda 9.1
It's unfortunate that this was only added in 9.1, but I haven't found a good
way of emulating the behavior on 9.0 without falling back to non-batched gemms.
PiperOrigin-RevId: 207286575
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 34 |
1 file changed, 26 insertions(+), 8 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index e5b5f7fc94..ab7091b3f5 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -292,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #if CUDA_VERSION >= 9010 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx) #endif } // namespace wrap @@ -2636,16 +2637,33 @@ bool CUDABlas::DoBlasGemmStridedBatched( bool use_tensor_ops = false; #if CUDA_VERSION >= 9000 int cc_major, cc_minor; - stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - - // GPUs < sm_70 don't support tensor ops. - if (cc_major >= 7 && TensorOpMathEnabled()) { - use_tensor_ops = true; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor)) { + // GPUs < sm_70 don't support tensor ops. + if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#if CUDA_VERSION >= 9010 + if (cc_major >= 5) { + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + bool ok = DoBlasInternalImpl( + wrap::cublasGemmStridedBatchedEx, stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a, + CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); + if (ok) { + return true; + } + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } +#endif } #endif - // We'd need cublasgemmStridedBatchedEx for this, which isn't available before - // CUDA 9.1. Fall back to a loop. + // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. 
for (int batch = 0; batch < batch_count; ++batch) { const auto *a_matrix = reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a); |