From 43a465f93a5f21d5015d3d39e028dc5269585c4a Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Tue, 31 Jul 2018 15:57:25 -0700
Subject: [SE] Add missing cublas algorithms for cuda 9.0,
 CUBLAS_GEMM_ALGO{3,4}_TENSOR_OP.

These appear to have been omitted by mistake.

PiperOrigin-RevId: 206843312
---
 tensorflow/stream_executor/cuda/cuda_blas.cc | 51 ++++++++++++++++++----------
 1 file changed, 33 insertions(+), 18 deletions(-)

(limited to 'tensorflow/stream_executor')

diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..67babd7f79 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2223,26 +2223,41 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
 
 bool CUDABlas::GetBlasGemmAlgorithms(
     std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
-  for (cublasGemmAlgo_t algo : {
-           CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
-           CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
-           CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+  // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+  // were first introduced in CUDA 8.
+  //
+  // Note that when CUDA version and compute capability is not sufficient, we
+  // still return the out_algorithms. Caller needs to make sure that in this
+  // case, the returned vector is empty.
+  *out_algorithms = {
+      CUBLAS_GEMM_DFALT,
+      CUBLAS_GEMM_ALGO0,
+      CUBLAS_GEMM_ALGO1,
+      CUBLAS_GEMM_ALGO2,
+      CUBLAS_GEMM_ALGO3,
+      CUBLAS_GEMM_ALGO4,
+      CUBLAS_GEMM_ALGO5,
+      CUBLAS_GEMM_ALGO6,
+      CUBLAS_GEMM_ALGO7,
 #if CUDA_VERSION >= 9000
-           CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
-           CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
-           CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
-           CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
-           CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
-           CUBLAS_GEMM_ALGO2_TENSOR_OP
+      CUBLAS_GEMM_ALGO8,
+      CUBLAS_GEMM_ALGO9,
+      CUBLAS_GEMM_ALGO10,
+      CUBLAS_GEMM_ALGO11,
+      CUBLAS_GEMM_ALGO12,
+      CUBLAS_GEMM_ALGO13,
+      CUBLAS_GEMM_ALGO14,
+      CUBLAS_GEMM_ALGO15,
+      CUBLAS_GEMM_ALGO16,
+      CUBLAS_GEMM_ALGO17,
+      CUBLAS_GEMM_DFALT_TENSOR_OP,
+      CUBLAS_GEMM_ALGO0_TENSOR_OP,
+      CUBLAS_GEMM_ALGO1_TENSOR_OP,
+      CUBLAS_GEMM_ALGO2_TENSOR_OP,
+      CUBLAS_GEMM_ALGO3_TENSOR_OP,
+      CUBLAS_GEMM_ALGO4_TENSOR_OP,
 #endif
-       }) {
-    out_algorithms->push_back(algo);
-  }
+  };
   return true;
 }
-- 
cgit v1.2.3
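
For context, a rough sketch of how the expanded algorithm list is typically consumed: a caller queries the supported cuBLAS GEMM algorithms and times each one to pick the fastest. Only GetBlasGemmAlgorithms and the blas::AlgorithmType typedef come from the code touched by this patch; PickFastestGemmAlgorithm and run_gemm_once below are hypothetical scaffolding, not StreamExecutor API.

    // Hypothetical autotuning sketch (not part of the patch). run_gemm_once
    // stands in for whatever launches a GEMM with a given algorithm and
    // reports its runtime in milliseconds (negative if the algorithm is
    // rejected by the device or CUDA version).
    #include <cstdint>
    #include <functional>
    #include <limits>
    #include <vector>

    namespace blas {
    using AlgorithmType = int64_t;  // Assumed to match stream_executor's typedef.
    }  // namespace blas

    blas::AlgorithmType PickFastestGemmAlgorithm(
        const std::vector<blas::AlgorithmType>& algorithms,
        const std::function<double(blas::AlgorithmType)>& run_gemm_once) {
      blas::AlgorithmType best = algorithms.empty() ? 0 : algorithms[0];
      double best_ms = std::numeric_limits<double>::infinity();
      for (blas::AlgorithmType algo : algorithms) {
        // Time one GEMM launch with this algorithm; skip unsupported ones.
        double ms = run_gemm_once(algo);
        if (ms >= 0 && ms < best_ms) {
          best_ms = ms;
          best = algo;
        }
      }
      return best;
    }

A real caller would fill `algorithms` via the GetBlasGemmAlgorithms entry point modified above and benchmark through the GEMM routines that accept an algorithm argument, which this sketch abstracts behind run_gemm_once; with this patch, the two extra tensor-op algorithms simply become additional candidates in that loop on CUDA 9.0+.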