aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_blas.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc51
1 files changed, 33 insertions, 18 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..67babd7f79 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2223,26 +2223,41 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
- for (cublasGemmAlgo_t algo : {
- CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+ // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+ // were first introduced in CUDA 8.
+ //
+ // Note that when CUDA version and compute capability is not sufficient, we
+ // still return the out_algorithms. Caller needs to make sure that in this
+ // case, the returned vector is empty.
+ *out_algorithms = {
+ CUBLAS_GEMM_DFALT,
+ CUBLAS_GEMM_ALGO0,
+ CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2,
+ CUBLAS_GEMM_ALGO3,
+ CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5,
+ CUBLAS_GEMM_ALGO6,
+ CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
- CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
- CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
- CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
- CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
- CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
- CUBLAS_GEMM_ALGO2_TENSOR_OP
+ CUBLAS_GEMM_ALGO8,
+ CUBLAS_GEMM_ALGO9,
+ CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11,
+ CUBLAS_GEMM_ALGO12,
+ CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14,
+ CUBLAS_GEMM_ALGO15,
+ CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17,
+ CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP,
+ CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP,
+ CUBLAS_GEMM_ALGO3_TENSOR_OP,
+ CUBLAS_GEMM_ALGO4_TENSOR_OP,
#endif
- }) {
- out_algorithms->push_back(algo);
- }
+ };
return true;
}