Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc | 34
1 file changed, 26 insertions, 8 deletions
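
For context, the new fast path in this patch calls cublasGemmStridedBatchedEx directly when CUDA_VERSION >= 9010 and the GPU is sm_50 or newer, with fp16 inputs and fp32 accumulation. Below is a minimal standalone sketch of that cuBLAS call outside StreamExecutor; it is not part of this patch, the function name gemm_strided_batched_fp16 and the d_a/d_b/d_c device buffers are hypothetical, error handling is omitted, and a CUDA 9.x/10.x toolchain is assumed (in CUDA 11+ the compute-type argument becomes cublasComputeType_t, e.g. CUBLAS_COMPUTE_32F).

// Sketch only: strided-batched GEMM with fp16 data and fp32 compute,
// mirroring the call issued by DoBlasGemmStridedBatched in this patch.
#include <cublas_v2.h>
#include <cuda_fp16.h>

void gemm_strided_batched_fp16(const __half *d_a, const __half *d_b,
                               __half *d_c, int m, int n, int k,
                               long long stride_a, long long stride_b,
                               long long stride_c, int batch_count,
                               bool use_tensor_ops) {
  cublasHandle_t handle;
  cublasCreate(&handle);
  if (use_tensor_ops) {
    // Tensor ops only help on sm_70+; harmless to request elsewhere.
    cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
  }
  // Host pointer mode (the default): alpha/beta match the fp32 compute type.
  float alpha = 1.0f, beta = 0.0f;
  cublasGemmAlgo_t algo =
      use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
  cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      d_a, CUDA_R_16F, /*lda=*/m, stride_a,
      d_b, CUDA_R_16F, /*ldb=*/k, stride_b, &beta,
      d_c, CUDA_R_16F, /*ldc=*/m, stride_c,
      batch_count, /*computeType=*/CUDA_R_32F, algo);
  cublasDestroy(handle);
}
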
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index e5b5f7fc94..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -292,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
@@ -2636,16 +2637,33 @@ bool CUDABlas::DoBlasGemmStridedBatched(
bool use_tensor_ops = false;
#if CUDA_VERSION >= 9000
int cc_major, cc_minor;
- stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
- &cc_minor);
-
- // GPUs < sm_70 don't support tensor ops.
- if (cc_major >= 7 && TensorOpMathEnabled()) {
- use_tensor_ops = true;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
}
#endif
- // We'd need cublasgemmStridedBatchedEx for this, which isn't available before
- // CUDA 9.1. Fall back to a loop.
+ // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
for (int batch = 0; batch < batch_count; ++batch) {
const auto *a_matrix =
reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);