aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-13 10:20:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 10:23:16 -0700
commit4880423ae9d2785faaffccea965f5b223f1318b0 (patch)
tree0db8ecabb5a3c7d4a6cb3178a3d339d450318801 /tensorflow/stream_executor
parent65cefda2f9a62f29af51b3effa0725c180244576 (diff)
Detect configurations that would be hitting a bug in cuBLAS and report an error.
PiperOrigin-RevId: 200411493
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 08fe153b59..92c1a5fc07 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2155,10 +2155,7 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
-// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
-#if CUDA_VERSION < 8000
- return false;
-#else
+ // GPUs < sm_50 don't support cublasGemmEx.
int cc_major, cc_minor;
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) &&
@@ -2184,6 +2181,13 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
}
}
+ // Return false if we might be hitting a cuBLAS bug that produces the wrong
+ // result. See nvbugs/2156201, b/79126339.
+ if (CUDA_VERSION < 9020 && algorithm != CUBLAS_GEMM_ALGO12 &&
+ std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+ return false;
+ }
+
cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
// we do the following compile-time check on the default value:
@@ -2213,7 +2217,6 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
timer->GetElapsedMilliseconds());
}
return result;
-#endif
}
bool CUDABlas::GetBlasGemmAlgorithms(