diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-17 10:14:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 10:17:33 -0700 |
commit | d568d6362d09031cf1a483aadc15fc68ec034767 (patch) | |
tree | 62c410a2392b7d129e9c18e30cbef80cf83938b2 /tensorflow/core/kernels/cuda_solvers.h | |
parent | c1f69be22e151e2d051f41fccf436767eee4a26a (diff) |
Add GPU implementation of tf.matrix_solve.
Add benchmark for tf.matrix_solve.
PiperOrigin-RevId: 165593559
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.h')
-rw-r--r-- | tensorflow/core/kernels/cuda_solvers.h | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index 8166bbc505..ac6119d8a2 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -197,6 +197,15 @@ class CudaSolver { int* dev_pivots, DeviceLapackInfo* dev_lapack_info, int batch_size) const; + // Batched linear solver using LU factorization from getrfBatched. + // See: + // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched + template <typename Scalar> + Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, + const Scalar* dev_Aarray[], int lda, const int* devIpiv, + const Scalar* dev_Barray[], int ldb, + DeviceLapackInfo* dev_lapack_info, int batch_size) const; + // Computes matrix inverses for a batch of small matrices. Uses the outputs // from GetrfBatched. Returns Status::OK() if the kernel was launched // successfully. See: @@ -255,15 +264,6 @@ class CudaSolver { Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, int ldvt, int* dev_lapack_info); - - // Batched linear solver using LU factorization from getrfBatched. - // See: - http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched - template <typename Scalar> - Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, - const Scalar* dev_Aarray[], int lda, const int* devIpiv, - Scalar* dev_Barray[], int ldb, int* info, int batch_size) - const; */ private: |