diff options
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.h')
-rw-r--r-- | tensorflow/core/kernels/cuda_solvers.h | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index 0fd6450f98..7cbdc895dd 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -258,13 +258,23 @@ class CudaSolver { Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar* dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const; +*/ // Singular value decomposition. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd template <typename Scalar> Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, - int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, - int ldvt, int* dev_lapack_info); - */ + int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, + int ldvt, int* dev_lapack_info) const; + /* + // Batched linear solver using LU factorization from getrfBatched. + // See: + http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched + template <typename Scalar> + Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, + const Scalar* dev_Aarray[], int lda, const int* devIpiv, + Scalar* dev_Barray[], int ldb, int* info, int batch_size) + const; + */ private: OpKernelContext* context_; // not owned. |