about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
author: Justin Lebar <jlebar@google.com> 2018-07-31 15:57:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-31 16:02:07 -0700
commit: 43a465f93a5f21d5015d3d39e028dc5269585c4a (patch)
tree: 68cb1219e39d44b6bea7ef9df5969a1cec60e13f /tensorflow/stream_executor
parent: 15167d56fd7c9f7870ceebd360f34839d04827c5 (diff)
[SE] Add missing cublas algorithms for cuda 9.0, CUBLAS_GEMM_ALGO{3,4}_TENSOR_OP.
These appear to have been omitted by mistake.

PiperOrigin-RevId: 206843312
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_blas.cc 51
1 files changed, 33 insertions, 18 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..67babd7f79 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2223,26 +2223,41 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
- for (cublasGemmAlgo_t algo : {
- CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+ // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+ // were first introduced in CUDA 8.
+ //
+ // Note that when CUDA version and compute capability is not sufficient, we
+ // still return the out_algorithms. Caller needs to make sure that in this
+ // case, the returned vector is empty.
+ *out_algorithms = {
+ CUBLAS_GEMM_DFALT,
+ CUBLAS_GEMM_ALGO0,
+ CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2,
+ CUBLAS_GEMM_ALGO3,
+ CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5,
+ CUBLAS_GEMM_ALGO6,
+ CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
- CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
- CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
- CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
- CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
- CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
- CUBLAS_GEMM_ALGO2_TENSOR_OP
+ CUBLAS_GEMM_ALGO8,
+ CUBLAS_GEMM_ALGO9,
+ CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11,
+ CUBLAS_GEMM_ALGO12,
+ CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14,
+ CUBLAS_GEMM_ALGO15,
+ CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17,
+ CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP,
+ CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP,
+ CUBLAS_GEMM_ALGO3_TENSOR_OP,
+ CUBLAS_GEMM_ALGO4_TENSOR_OP,
#endif
- }) {
- out_algorithms->push_back(algo);
- }
+ };
return true;
}