path: root/tensorflow/core/kernels/cuda_solvers.cc
author A. Unique TensorFlower <gardener@tensorflow.org> 2017-12-19 12:38:19 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-12-19 12:42:09 -0800
commit a37d4ae2e63648f0225b1a514b50642101c3161b (patch)
tree f89797cb1f43ab42a876ce1f404830ba10ad15ea /tensorflow/core/kernels/cuda_solvers.cc
parent b4e4677a62e1ce4f47dc7c80fedb59527cfc800b (diff)
Protect all calls that launch cuSolver & cuBlas kernels with a lock. The code appears not to be threadsafe pre Cuda 9, and we have several reports of crashes. Since the overhead is modest, it is better to be safe.
PiperOrigin-RevId: 179589983
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.cc')
-rw-r--r-- tensorflow/core/kernels/cuda_solvers.cc | 24
1 file changed, 16 insertions(+), 8 deletions(-)
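The change itself is mechanical: each Impl function that launches a cuSolver or cuBlas kernel now acquires the global handle_map_mutex for the duration of the call. Below is a minimal standalone sketch of that pattern; it models TensorFlow's mutex/mutex_lock with std::mutex/std::lock_guard, and LaunchSolverKernel is a hypothetical stand-in for a real launch such as cusolverDnSpotrf.

#include <mutex>

// Global lock serializing all solver launches; the diff reuses the
// pre-existing handle_map_mutex for this purpose.
static std::mutex handle_map_mutex;

// Hypothetical stand-in for a cuSolver/cuBlas call that only enqueues a
// kernel on the shared CUDA stream and returns a status code.
static int LaunchSolverKernel() { return 0; }

static int GuardedLaunch() {
  // RAII lock, released when the function returns. The guarded call only
  // enqueues work on the stream, so the lock is held briefly and the GPU
  // execution itself is not serialized.
  std::lock_guard<std::mutex> lock(handle_map_mutex);
  return LaunchSolverKernel();
}

Because the launch returns as soon as the kernel is queued, the per-call overhead of the lock is small, which is why the commit message calls it modest.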
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index a83671a471..6cec032f94 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor(
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
+//
+// NOTE: The cuSolver functions called below appear not to be threadsafe,
+// so we put a global lock around the calls. Since these functions only put a
+// kernel on the shared stream, it is not a big performance hit.
+// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================
template <typename Scalar, typename SolverFnT>
@@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */
const Scalar* B, int ldb, Scalar* C, int ldc) {
+ mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
@@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle,
cublasFillMode_t uplo, int n, Scalar* A, int lda,
int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, int* dev_pivots,
int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
cublasOperation_t trans, int n, int nrhs,
const Scalar* A, int lda, const int* pivots,
Scalar* B, int ldb, int* dev_lapack_info) {
- // Note: The cuSolver functions called here appear not to be threadsafe.
- // so we put a global lock around it. Since this function only puts a
- // kernel on the stream, it is not a big performance hit.
mutex_lock lock(handle_map_mutex);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
@@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, Scalar* tau,
int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
int m, int n, int k, const Scalar* dev_a,
int lda, const Scalar* dev_tau, Scalar* dev_c,
int ldc, int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau, int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
@@ -606,17 +614,13 @@ static inline Status GesvdImpl(
OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
+ mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
/* Allocate device memory for workspace. */
auto dev_workspace =
cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
- // Note: The cuSolver functions called here appear not to be threadsafe.
- // so we put a global lock around it. Since this function only puts a
- // kernel on the stream, it is not a big performance hit.
- mutex_lock lock(handle_map_mutex);
- /* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
CUDAComplex(A), lda, S, CUDAComplex(U),
ldu, CUDAComplex(VT), ldvt,
@@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
int lda, int* dev_pivots,
DeviceLapackInfo* dev_lapack_info,
int batch_size) {
+ mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl(
const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
const Scalar* const host_b_dev_ptrs[], int ldb,
DeviceLapackInfo* dev_lapack_info, int batch_size) {
+ mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
+ mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) {
+ mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",