diff options
author | 2016-06-11 10:45:56 -0800 | |
---|---|---|
committer | 2016-06-11 11:49:48 -0700 | |
commit | 5a65d43a9e456aac08b32a1d38cbe123d73fcda8 (patch) | |
tree | 2f30b8d138b429b2ac3e436e11a0df4eb5af55ef /tensorflow/stream_executor/cuda/cuda_blas.cc | |
parent | f34e3976228e0c15d979786823c83a780eebb212 (diff) |
Merge changes from github.
Change: 124644444
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 224803fc84..a9dd2953e5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -25,6 +25,12 @@ limitations under the License.
 #define EIGEN_HAS_CUDA_FP16
 #endif
 
+#if CUDA_VERSION >= 8000
+#define SE_CUDA_DATA_HALF CUDA_R_16F
+#else
+#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
+#endif
+
 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
 
 #include <dlfcn.h>
@@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm(
   return DoBlasInternal(
       dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-      CUDAMemory(a), CUBLAS_DATA_HALF, lda,
-      CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
+      CUDAMemory(a), SE_CUDA_DATA_HALF, lda,
+      CUDAMemory(b), SE_CUDA_DATA_HALF, ldb,
       &beta,
-      CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
+      CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
 #else
   LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
              << "(need at least CUDA 7.5)";