author     2017-12-19 12:38:19 -0800
committer  2017-12-19 12:42:09 -0800
commit     a37d4ae2e63648f0225b1a514b50642101c3161b (patch)
tree       f89797cb1f43ab42a876ce1f404830ba10ad15ea /tensorflow/core/kernels/cuda_solvers.cc
parent     b4e4677a62e1ce4f47dc7c80fedb59527cfc800b (diff)
Protect all calls that launch cuSolver & cuBlas kernels with a lock. The code appears not to be threadsafe pre Cuda 9, and we have several reports of crashes. Since the overhead is modest, it is better to be safe.
PiperOrigin-RevId: 179589983
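
The change itself is mechanical: each *Impl wrapper now takes a process-wide lock before calling into cuSolver or cuBlas. Below is a minimal sketch of the pattern, assuming plain std::mutex and std::lock_guard in place of TensorFlow's mutex_lock on handle_map_mutex, and with LaunchKernelOnSharedStream() as a hypothetical stand-in for any cuSolver/cuBlas call that enqueues a kernel:

#include <mutex>

namespace {
// Process-wide lock. It serializes only the kernel *enqueue*, not the GPU
// execution of the kernel itself.
std::mutex solver_enqueue_mutex;

// Hypothetical stand-in for a cuSolver/cuBlas call that places a kernel on
// the shared stream and returns immediately.
void LaunchKernelOnSharedStream() {}
}  // namespace

void GuardedSolverCall() {
  // Hold the lock only for the cheap enqueue step; the GPU runs the kernel
  // asynchronously after we return, which is why the overhead is modest.
  std::lock_guard<std::mutex> lock(solver_enqueue_mutex);
  LaunchKernelOnSharedStream();
}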
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.cc')
-rw-r--r--  tensorflow/core/kernels/cuda_solvers.cc  24
1 file changed, 16 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index a83671a471..6cec032f94 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor(
 // are sometimes inaccurate, e.g., are missing 'const' on pointers
 // to immutable arguments, while the actual headers have them as expected.
 // Check the actual declarations in the cusolver_api.h header file.
+//
+// NOTE: The cuSolver functions called below appear not to be threadsafe.
+// so we put a global lock around the calls. Since these functions only put a
+// kernel on the shared stream, it is not a big performance hit.
+// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
 //=============================================================================
 
 template <typename Scalar, typename SolverFnT>
@@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                               const Scalar* A, int lda,
                               const Scalar* beta, /* host or device pointer */
                               const Scalar* B, int ldb, Scalar* C, int ldc) {
+  mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                    reinterpret_cast<const CudaScalar*>(alpha),
@@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                                cusolverDnHandle_t cusolver_dn_handle,
                                cublasFillMode_t uplo, int n, Scalar* A, int lda,
                                int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(
@@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                                cusolverDnHandle_t cusolver_dn_handle, int m,
                                int n, Scalar* A, int lda, int* dev_pivots,
                                int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(
@@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                                cublasOperation_t trans, int n, int nrhs,
                                const Scalar* A, int lda, const int* pivots,
                                Scalar* B, int ldb, int* dev_lapack_info) {
-  // Note: The cuSolver functions called here appear not to be threadsafe.
-  // so we put a global lock around it. Since this function only puts a
-  // kernel on the stream, it is not a big performance hit.
   mutex_lock lock(handle_map_mutex);
   /* Launch the solver kernel. */
   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
@@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                                cusolverDnHandle_t cusolver_dn_handle, int m,
                                int n, Scalar* A, int lda, Scalar* tau,
                                int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(
@@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                                int m, int n, int k, const Scalar* dev_a,
                                int lda, const Scalar* dev_tau, Scalar* dev_c,
                                int ldc, int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(
@@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                                cusolverDnHandle_t cusolver_dn_handle, int m,
                                int n, int k, Scalar* dev_a, int lda,
                                const Scalar* dev_tau, int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
@@ -606,17 +614,13 @@ static inline Status GesvdImpl(
     OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
     signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
     Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt,
     int* dev_lapack_info) {
+  mutex_lock lock(handle_map_mutex);
   /* Get amount of workspace memory required. */
   int lwork;
   TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
   /* Allocate device memory for workspace. */
   auto dev_workspace =
       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
-  // Note: The cuSolver functions called here appear not to be threadsafe.
-  // so we put a global lock around it. Since this function only puts a
-  // kernel on the stream, it is not a big performance hit.
-  mutex_lock lock(handle_map_mutex);
-
   /* Launch the solver kernel. */
   TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                      CUDAComplex(A), lda, S, CUDAComplex(U),
                                      ldu, CUDAComplex(VT), ldvt,
@@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
                                       int lda, int* dev_pivots,
                                       DeviceLapackInfo* dev_lapack_info,
                                       int batch_size) {
+  mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl(
     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
     const Scalar* const host_b_dev_ptrs[], int ldb,
     DeviceLapackInfo* dev_lapack_info, int batch_size) {
+  mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl(
     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
     int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
     int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
+  mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl(
     cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
     int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
     DeviceLapackInfo* dev_lapack_info, int batch_size) {
+  mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
       cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
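
For intuition about the race being guarded against, here is a self-contained, hedged demonstration; every name in it is illustrative, and the real change simply adds mutex_lock lock(handle_map_mutex); at the top of each wrapper shown above. A plain counter stands in for the library state that pre-Cuda 9 cuSolver/cuBlas apparently mutate without protection:

#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

std::mutex enqueue_mutex;
int enqueue_count = 0;  // stands in for shared, non-threadsafe library state

void GuardedEnqueue() {
  // Without this lock the increments below would race, much as concurrent
  // unguarded cuSolver/cuBlas calls appear to race before Cuda 9.
  std::lock_guard<std::mutex> lock(enqueue_mutex);
  ++enqueue_count;
}

int main() {
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([] {
      for (int j = 0; j < 1000; ++j) GuardedEnqueue();
    });
  }
  for (auto& t : workers) t.join();
  std::printf("enqueues: %d\n", enqueue_count);  // always 4000 with the lock
  return 0;
}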