aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_blas.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc217
1 files changed, 197 insertions, 20 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
#if CUDA_VERSION >= 8000
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
#endif
#if CUDA_VERSION >= 9000
@@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
@@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
}
#endif
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
- if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
+ if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
<< ToString(ret);
}
@@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm(
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
- // GPUs < sm_70 don't support Volta hardware.
+ // GPUs < sm_70 don't support tensor ops.
if (cc_major >= 7 && TensorOpMathEnabled()) {
use_tensor_ops = true;
}
@@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
#if CUDA_VERSION >= 9000
+ // cublas *does* allow tensor ops on inputs that are not fp16, so this is not
+ // strictly correct. We can't simply enable it, though, as that would change
+ // clients' behavior significantly: Using tensor ops on fp32 inputs cause them
+ // to be rounded to fp16.
if (cc_major >= 7 && TensorOpMathEnabled() &&
std::is_same<InType, Eigen::half>::value) {
return true;
@@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) &&
cc_major < 5) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
+ << cc_minor << " devices don't support explicit gemm algorithms.";
return false;
}
if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+ if (std::is_same<InT, Eigen::half>::value) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but tensor ops are not available in sm"
+ << cc_major << "X devices.";
+ } else {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but the input data type is not fp16.";
+ }
return false;
}
// Either both 'alpha' and 'beta' need to be pointers to device memory, or
// they need to be both host scalars.
if (alpha.is_pointer() != beta.is_pointer()) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
+ "and `beta` is a pointer, but the other is not.";
return false;
}
@@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
+ "output_profile_result was given, but we were unable to "
+ "create a CUDATimer.";
return false;
}
}
@@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn "
+ "<9.2 bug with m, n, or k >= 2097153. See b/79126339.";
return false;
}
#endif
@@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
+ "CUDATimer.";
return false;
}
output_profile_result->set_is_valid(true);
@@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
- for (cublasGemmAlgo_t algo : {
- CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+ // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+ // were first introduced in CUDA 8.
+ //
+ // Note that when CUDA version and compute capability is not sufficient, we
+ // still return the out_algorithms. Caller needs to make sure that in this
+ // case, the returned vector is empty.
+ *out_algorithms = {
+ CUBLAS_GEMM_DFALT,
+ CUBLAS_GEMM_ALGO0,
+ CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2,
+ CUBLAS_GEMM_ALGO3,
+ CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5,
+ CUBLAS_GEMM_ALGO6,
+ CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
- CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
- CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
- CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
- CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
- CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
- CUBLAS_GEMM_ALGO2_TENSOR_OP
+ CUBLAS_GEMM_ALGO8,
+ CUBLAS_GEMM_ALGO9,
+ CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11,
+ CUBLAS_GEMM_ALGO12,
+ CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14,
+ CUBLAS_GEMM_ALGO15,
+ CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17,
+ CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP,
+ CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP,
+ CUBLAS_GEMM_ALGO3_TENSOR_OP,
+ CUBLAS_GEMM_ALGO4_TENSOR_OP,
#endif
- }) {
- out_algorithms->push_back(algo);
- }
+#if CUDA_VERSION >= 9200
+ CUBLAS_GEMM_ALGO18,
+ CUBLAS_GEMM_ALGO19,
+ CUBLAS_GEMM_ALGO20,
+ CUBLAS_GEMM_ALGO21,
+ CUBLAS_GEMM_ALGO22,
+ CUBLAS_GEMM_ALGO23,
+ CUBLAS_GEMM_ALGO5_TENSOR_OP,
+ CUBLAS_GEMM_ALGO6_TENSOR_OP,
+ CUBLAS_GEMM_ALGO7_TENSOR_OP,
+ CUBLAS_GEMM_ALGO8_TENSOR_OP,
+ CUBLAS_GEMM_ALGO9_TENSOR_OP,
+ CUBLAS_GEMM_ALGO10_TENSOR_OP,
+ CUBLAS_GEMM_ALGO11_TENSOR_OP,
+ CUBLAS_GEMM_ALGO12_TENSOR_OP,
+ CUBLAS_GEMM_ALGO13_TENSOR_OP,
+ CUBLAS_GEMM_ALGO14_TENSOR_OP,
+ CUBLAS_GEMM_ALGO15_TENSOR_OP,
+#endif
+ };
return true;
}
@@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched(
return status.ok();
}
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
+ }
+#endif
+ // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
+ for (int batch = 0; batch < batch_count; ++batch) {
+ const auto *a_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+ const auto *b_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+ auto *c_matrix =
+ reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+ true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+ lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+ SE_CUDA_DATA_HALF, ldc);
+ if (!ok) {
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+ }
+ return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
blas::UpperLower uplo, uint64 m, uint64 n,
std::complex<float> alpha,