Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.cc')
 tensorflow/stream_executor/cuda/cuda_blas.cc | 217 ++++++++++++++++++++++++---
 1 file changed, 197 insertions(+), 20 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
 
 #if CUDA_VERSION >= 8000
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
 #endif
 
 #if CUDA_VERSION >= 9000
@@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
 
 #if CUDA_VERSION >= 9010
 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
 #endif
 
 }  // namespace wrap
@@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
   }
 #endif
   cublasStatus_t ret = cublas_func(parent_, blas_, args...);
-  if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
+  if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
                << ToString(ret);
   }
@@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm(
     stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                      &cc_minor);
 
-    // GPUs < sm_70 don't support Volta hardware.
+    // GPUs < sm_70 don't support tensor ops.
     if (cc_major >= 7 && TensorOpMathEnabled()) {
       use_tensor_ops = true;
     }
@@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
 template <typename InType>
 static bool TensorOpsAvailable(int cc_major) {
 #if CUDA_VERSION >= 9000
+  // cublas *does* allow tensor ops on inputs that are not fp16, so this is
+  // not strictly correct.  We can't simply enable it, though, as that would
+  // change clients' behavior significantly: Using tensor ops on fp32 inputs
+  // causes them to be rounded to fp16.
   if (cc_major >= 7 && TensorOpMathEnabled() &&
       std::is_same<InType, Eigen::half>::value) {
     return true;
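Note on the hunk above: the switch that TensorOpMathEnabled() ultimately gates is cuBLAS's math mode, reached through the cublasSetMathMode wrapper added near the top of this file. A minimal sketch of what opting a handle into tensor-op math looks like against the raw CUDA 9-era API (EnableTensorOps is an illustrative name, not a function in this file):

#include <cublas_v2.h>

// Opt a cuBLAS handle into tensor-op math (CUDA 9.x API).  On sm_70+ this
// permits tensor-core kernels for eligible GEMMs; as the comment above warns,
// applying it to fp32 inputs implies rounding them to fp16 first.
cublasStatus_t EnableTensorOps(cublasHandle_t handle) {
  return cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
}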
@@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
   if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
           &cc_major, &cc_minor) &&
       cc_major < 5) {
+    VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
+            << cc_minor << " devices don't support explicit gemm algorithms.";
     return false;
   }
 
   if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+    if (std::is_same<InT, Eigen::half>::value) {
+      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+              << algorithm
+              << " uses tensor ops, but tensor ops are not available in sm"
+              << cc_major << "X devices.";
+    } else {
+      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+              << algorithm
+              << " uses tensor ops, but the input data type is not fp16.";
+    }
     return false;
   }
 
   // Either both 'alpha' and 'beta' need to be pointers to device memory, or
   // they need to be both host scalars.
   if (alpha.is_pointer() != beta.is_pointer()) {
+    VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
+               "and `beta` is a pointer, but the other is not.";
     return false;
   }
 
@@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
   if (output_profile_result != nullptr) {
     timer.reset(new CUDATimer(parent_));
     if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+      VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
+                 "output_profile_result was given, but we were unable to "
+                 "create a CUDATimer.";
       return false;
     }
   }
@@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
   if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
       std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+    VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn "
+               "<9.2 bug with m, n, or k >= 2097153.  See b/79126339.";
     return false;
   }
 #endif
@@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
   // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
   // state.
   if (!timer->Stop(AsCUDAStream(stream))) {
+    VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
+               "CUDATimer.";
     return false;
   }
   output_profile_result->set_is_valid(true);
@@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
 
 bool CUDABlas::GetBlasGemmAlgorithms(
     std::vector<blas::AlgorithmType> *out_algorithms) {
-  // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-  // were first introduced in CUDA 8.
-  // Note that when CUDA version and compute capability is not sufficient, we
-  // still return the out_algorithms. Caller needs to make sure that in this case,
-  // the returned vector is empty.
-  for (cublasGemmAlgo_t algo : {
-           CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
-           CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
-           CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+  // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+  // were first introduced in CUDA 8.
+  //
+  // Note that when the CUDA version and compute capability are not sufficient,
+  // we still return out_algorithms.  The caller needs to make sure that in
+  // this case, the returned vector is empty.
+  *out_algorithms = {
+      CUBLAS_GEMM_DFALT,
+      CUBLAS_GEMM_ALGO0,
+      CUBLAS_GEMM_ALGO1,
+      CUBLAS_GEMM_ALGO2,
+      CUBLAS_GEMM_ALGO3,
+      CUBLAS_GEMM_ALGO4,
+      CUBLAS_GEMM_ALGO5,
+      CUBLAS_GEMM_ALGO6,
+      CUBLAS_GEMM_ALGO7,
 #if CUDA_VERSION >= 9000
-           CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
-           CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
-           CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
-           CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
-           CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
-           CUBLAS_GEMM_ALGO2_TENSOR_OP
+      CUBLAS_GEMM_ALGO8,
+      CUBLAS_GEMM_ALGO9,
+      CUBLAS_GEMM_ALGO10,
+      CUBLAS_GEMM_ALGO11,
+      CUBLAS_GEMM_ALGO12,
+      CUBLAS_GEMM_ALGO13,
+      CUBLAS_GEMM_ALGO14,
+      CUBLAS_GEMM_ALGO15,
+      CUBLAS_GEMM_ALGO16,
+      CUBLAS_GEMM_ALGO17,
+      CUBLAS_GEMM_DFALT_TENSOR_OP,
+      CUBLAS_GEMM_ALGO0_TENSOR_OP,
+      CUBLAS_GEMM_ALGO1_TENSOR_OP,
+      CUBLAS_GEMM_ALGO2_TENSOR_OP,
+      CUBLAS_GEMM_ALGO3_TENSOR_OP,
+      CUBLAS_GEMM_ALGO4_TENSOR_OP,
 #endif
-  }) {
-    out_algorithms->push_back(algo);
-  }
+#if CUDA_VERSION >= 9200
+      CUBLAS_GEMM_ALGO18,
+      CUBLAS_GEMM_ALGO19,
+      CUBLAS_GEMM_ALGO20,
+      CUBLAS_GEMM_ALGO21,
+      CUBLAS_GEMM_ALGO22,
+      CUBLAS_GEMM_ALGO23,
+      CUBLAS_GEMM_ALGO5_TENSOR_OP,
+      CUBLAS_GEMM_ALGO6_TENSOR_OP,
+      CUBLAS_GEMM_ALGO7_TENSOR_OP,
+      CUBLAS_GEMM_ALGO8_TENSOR_OP,
+      CUBLAS_GEMM_ALGO9_TENSOR_OP,
+      CUBLAS_GEMM_ALGO10_TENSOR_OP,
+      CUBLAS_GEMM_ALGO11_TENSOR_OP,
+      CUBLAS_GEMM_ALGO12_TENSOR_OP,
+      CUBLAS_GEMM_ALGO13_TENSOR_OP,
+      CUBLAS_GEMM_ALGO14_TENSOR_OP,
+      CUBLAS_GEMM_ALGO15_TENSOR_OP,
+#endif
+  };
   return true;
 }
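The algorithm list above is intended for callers that autotune: try each candidate, time it, keep the fastest. A hedged sketch of that loop against raw cuBLAS (CUDA 9/10-era cublasGemmEx signature, where the compute type is still a cudaDataType_t; PickBestGemmAlgo and its device buffers are illustrative names, not part of this file, and per-algorithm failures are simply skipped, mirroring how unsupported algorithms return non-success):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

// Sketch: time one fp16 GEMM (column-major, no transposes) under each
// advertised algorithm and return the fastest.  d_a is m x k, d_b is k x n,
// d_c is m x n, all device buffers.
cublasGemmAlgo_t PickBestGemmAlgo(
    cublasHandle_t handle, int m, int n, int k, const __half *d_a,
    const __half *d_b, __half *d_c,
    const std::vector<cublasGemmAlgo_t> &algorithms) {
  float alpha = 1.0f, beta = 0.0f;  // fp32 scalars to match CUDA_R_32F compute.
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cublasGemmAlgo_t best = CUBLAS_GEMM_DFALT;
  float best_ms = 1e30f;
  for (cublasGemmAlgo_t algo : algorithms) {
    cudaEventRecord(start);
    cublasStatus_t status = cublasGemmEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, d_a, CUDA_R_16F, m,
        d_b, CUDA_R_16F, k, &beta, d_c, CUDA_R_16F, m, CUDA_R_32F, algo);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    if (status != CUBLAS_STATUS_SUCCESS) continue;  // algo unsupported here.
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    if (ms < best_ms) {
      best_ms = ms;
      best = algo;
    }
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return best;
}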
@@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched(
   return status.ok();
 }
 
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+  int cc_major, cc_minor;
+  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+          &cc_major, &cc_minor)) {
+    // GPUs < sm_70 don't support tensor ops.
+    if (cc_major >= 7 && TensorOpMathEnabled()) {
+      use_tensor_ops = true;
+    }
+#if CUDA_VERSION >= 9010
+    if (cc_major >= 5) {
+      cublasGemmAlgo_t algo =
+          (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+      bool ok = DoBlasInternalImpl(
+          wrap::cublasGemmStridedBatchedEx, stream,
+          true /* = pointer_mode_host */, true /* = err_on_failure */,
+          use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+          m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+          CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta,
+          CUDAMemoryMutable(c), CUDA_R_16F, ldc, stride_c, batch_count,
+          CUDA_R_32F, algo);
+      if (ok) {
+        return true;
+      }
+      LOG(ERROR) << "failed BLAS call, see log for details";
+      return false;
+    }
+#endif
+  }
+#endif
+  // Either CUDA_VERSION < 9.1 or SM < 5.0.  Fall back to a loop.
+  for (int batch = 0; batch < batch_count; ++batch) {
+    const auto *a_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+    const auto *b_matrix =
+        reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+    auto *c_matrix =
+        reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+    bool ok = DoBlasInternalImpl(
+        wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+        CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix,
+        SE_CUDA_DATA_HALF, lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta,
+        c_matrix, SE_CUDA_DATA_HALF, ldc);
+    if (!ok) {
+      LOG(ERROR) << "failed BLAS call, see log for details";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+      CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+      CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  return DoBlasInternal(
+      wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+      CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+      CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                           blas::UpperLower uplo, uint64 m, uint64 n,
                           std::complex<float> alpha,
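For reference, the stride semantics the new entry points forward to cuBLAS: matrix i of each batch starts at base + i * stride, which is exactly the address arithmetic the pre-CUDA-9.1 fallback loop above performs by hand. A minimal standalone sketch against the raw API, assuming contiguously packed square matrices (so stride = n * n) and zero-filled inputs purely so the call is well defined:

#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
  const int n = 64, batch = 8;
  const long long stride = static_cast<long long>(n) * n;  // contiguous packing
  float *d_a, *d_b, *d_c;
  cudaMalloc(&d_a, sizeof(float) * stride * batch);
  cudaMalloc(&d_b, sizeof(float) * stride * batch);
  cudaMalloc(&d_c, sizeof(float) * stride * batch);
  cudaMemset(d_a, 0, sizeof(float) * stride * batch);
  cudaMemset(d_b, 0, sizeof(float) * stride * batch);
  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;
  // One call covers all `batch` GEMMs; matrix i of a/b/c lives at
  // d_x + i * stride, matching the looped fallback above.
  cublasStatus_t status = cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, d_a, n, stride, d_b,
      n, stride, &beta, d_c, n, stride, batch);
  cublasDestroy(handle);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}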