aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_blas.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-06-11 10:45:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-11 11:49:48 -0700
commit5a65d43a9e456aac08b32a1d38cbe123d73fcda8 (patch)
tree2f30b8d138b429b2ac3e436e11a0df4eb5af55ef /tensorflow/stream_executor/cuda/cuda_blas.cc
parentf34e3976228e0c15d979786823c83a780eebb212 (diff)
Merge changes from github.
Change: 124644444
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 224803fc84..a9dd2953e5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -25,6 +25,12 @@ limitations under the License.
#define EIGEN_HAS_CUDA_FP16
#endif
+#if CUDA_VERSION >= 8000
+#define SE_CUDA_DATA_HALF CUDA_R_16F
+#else
+#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
+#endif
+
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
#include <dlfcn.h>
@@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm(
return DoBlasInternal(
dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
- CUDAMemory(a), CUBLAS_DATA_HALF, lda,
- CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
+ CUDAMemory(a), SE_CUDA_DATA_HALF, lda,
+ CUDAMemory(b), SE_CUDA_DATA_HALF, ldb,
&beta,
- CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
+ CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
#else
LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
<< "(need at least CUDA 7.5)";