aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cuda_solvers.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.h')
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 0fd6450f98..7cbdc895dd 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -258,13 +258,23 @@ class CudaSolver {
Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const;
+*/
// Singular value decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
- int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
- int ldvt, int* dev_lapack_info);
- */
+ int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
+ int ldvt, int* dev_lapack_info) const;
+ /*
+ // Batched linear solver using LU factorization from getrfBatched.
+ // See:
+ http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
+ template <typename Scalar>
+ Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
+ const Scalar* dev_Aarray[], int lda, const int* devIpiv,
+ Scalar* dev_Barray[], int ldb, int* info, int batch_size)
+ const;
+ */
private:
OpKernelContext* context_; // not owned.