diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-13 10:20:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-13 10:23:16 -0700 |
commit | 4880423ae9d2785faaffccea965f5b223f1318b0 (patch) | |
tree | 0db8ecabb5a3c7d4a6cb3178a3d339d450318801 /tensorflow/stream_executor | |
parent | 65cefda2f9a62f29af51b3effa0725c180244576 (diff) |
Detect configurations that would be hitting a bug in cuBLAS and report an error.
PiperOrigin-RevId: 200411493
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.cc | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 08fe153b59..92c1a5fc07 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -2155,10 +2155,7 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { -// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx. -#if CUDA_VERSION < 8000 - return false; -#else + // GPUs < sm_50 don't support cublasGemmEx. int cc_major, cc_minor; if (stream->parent()->GetDeviceDescription().cuda_compute_capability( &cc_major, &cc_minor) && @@ -2184,6 +2181,13 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( } } + // Return false if we might be hitting a cuBLAS bug that produces the wrong + // result. See nvbugs/2156201, b/79126339. + if (CUDA_VERSION < 9020 && algorithm != CUBLAS_GEMM_ALGO12 && + std::max({m, n, k}) >= 2097153 && cc_major < 7) { + return false; + } + cudaDataType_t cuda_in_type = CUDADataType<InT>::type; // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, // we do the following compile-time check on the default value: @@ -2213,7 +2217,6 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( timer->GetElapsedMilliseconds()); } return result; -#endif } bool CUDABlas::GetBlasGemmAlgorithms( |