path: root/tensorflow/stream_executor
author	Benjamin Kramer <kramerb@google.com>	2018-08-03 10:29:38 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-08-03 10:33:04 -0700
commit	1e55f802fd4bc30a3e68726e2914dba0c08ffbd8 (patch)
tree	e7d7234c073301ad56cdaf098e93502fc0ab2e72 /tensorflow/stream_executor
parent	e36c16c6720dd64ae5d8a1f8555102a1323af9ae (diff)
[XLA:GPU] Add a fast version of gemmStridedBatched for CUDA 9.1
It's unfortunate that this was only added in 9.1, but I haven't found a good way of emulating the behavior on 9.0 without falling back to non-batched gemms.

PiperOrigin-RevId: 207286575
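For context, a minimal sketch of the fast path this commit enables: a single cublasGemmStridedBatchedEx call (available from CUDA 9.1) covering all batches, with FP16 storage and FP32 accumulation as in the diff below. The helper name, the no-transpose operands, and the handle setup are illustrative assumptions, not TensorFlow code.

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Illustrative helper (hypothetical name): one strided-batched GEMM call
// instead of batch_count separate GEMMs. Requires CUDA >= 9.1.
cublasStatus_t HalfGemmStridedBatched(
    cublasHandle_t handle, int m, int n, int k, float alpha, float beta,
    const __half *a, int lda, long long stride_a,
    const __half *b, int ldb, long long stride_b,
    __half *c, int ldc, long long stride_c, int batch_count) {
  // FP16 operands (CUDA_R_16F) with FP32 compute (CUDA_R_32F); alpha and
  // beta are passed as floats to match the compute type.
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      a, CUDA_R_16F, lda, stride_a,
      b, CUDA_R_16F, ldb, stride_b, &beta,
      c, CUDA_R_16F, ldc, stride_c,
      batch_count, CUDA_R_32F, CUBLAS_GEMM_DFALT);
}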
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--	tensorflow/stream_executor/cuda/cuda_blas.cc	34
1 file changed, 26 insertions(+), 8 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index e5b5f7fc94..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -292,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
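The first hunk registers the new cuBLAS entry point with stream_executor's wrapper macro, guarded so it only compiles against CUDA 9.1 and later. As a rough sketch of the idea (the real STREAM_EXECUTOR_CUBLAS_WRAP also handles dynamic symbol loading; this simplified forwarding functor is an assumption for illustration):

#include <cublas_v2.h>

// Simplified stand-in for the wrapper macro: expands to a functor that
// forwards to the real cuBLAS symbol, so call sites can write
// wrap::cublasGemmStridedBatchedEx(...) behind a CUDA_VERSION guard.
#define STREAM_EXECUTOR_CUBLAS_WRAP(__name)           \
  struct Wrap##__name {                               \
    template <typename... Args>                       \
    cublasStatus_t operator()(Args... args) const {   \
      return ::__name(args...);                       \
    }                                                 \
  } __name;

namespace wrap {
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
}  // namespace wrap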
@@ -2636,16 +2637,33 @@ bool CUDABlas::DoBlasGemmStridedBatched(
bool use_tensor_ops = false;
#if CUDA_VERSION >= 9000
int cc_major, cc_minor;
- stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
- &cc_minor);
-
- // GPUs < sm_70 don't support tensor ops.
- if (cc_major >= 7 && TensorOpMathEnabled()) {
- use_tensor_ops = true;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
}
#endif
- // We'd need cublasgemmStridedBatchedEx for this, which isn't available before
- // CUDA 9.1. Fall back to a loop.
+ // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
for (int batch = 0; batch < batch_count; ++batch) {
const auto *a_matrix =
reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
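For comparison, a hedged sketch of the fallback path's shape: without the strided-batched entry point, one GEMM is issued per batch, offsetting each operand pointer by its stride. The direct cublasGemmEx call below stands in for the wrapped TensorFlow-internal call and is an assumption for illustration; handle setup and the surrounding function are omitted.

// Per-batch fallback (sketch): offset each operand by its stride and run
// a single FP16 GEMM with FP32 accumulation, mirroring the loop above.
for (int batch = 0; batch < batch_count; ++batch) {
  const __half *a_matrix = a + batch * stride_a;
  const __half *b_matrix = b + batch * stride_b;
  __half *c_matrix = c + batch * stride_c;
  cublasStatus_t status = cublasGemmEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      a_matrix, CUDA_R_16F, lda, b_matrix, CUDA_R_16F, ldb, &beta,
      c_matrix, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT);
  if (status != CUBLAS_STATUS_SUCCESS) return false;
}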