aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cuda_solvers.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-17 10:14:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 10:17:33 -0700
commitd568d6362d09031cf1a483aadc15fc68ec034767 (patch)
tree62c410a2392b7d129e9c18e30cbef80cf83938b2 /tensorflow/core/kernels/cuda_solvers.h
parentc1f69be22e151e2d051f41fccf436767eee4a26a (diff)
Add GPU implementation of tf.matrix_solve.
Add benchmark for tf.matrix_solve. PiperOrigin-RevId: 165593559
Diffstat (limited to 'tensorflow/core/kernels/cuda_solvers.h')
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h18
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 8166bbc505..ac6119d8a2 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -197,6 +197,15 @@ class CudaSolver {
int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
int batch_size) const;
+ // Batched linear solver using LU factorization from getrfBatched.
+ // See:
+ // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
+ template <typename Scalar>
+ Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
+ const Scalar* dev_Aarray[], int lda, const int* devIpiv,
+ const Scalar* dev_Barray[], int ldb,
+ DeviceLapackInfo* dev_lapack_info, int batch_size) const;
+
// Computes matrix inverses for a batch of small matrices. Uses the outputs
// from GetrfBatched. Returns Status::OK() if the kernel was launched
// successfully. See:
@@ -255,15 +264,6 @@ class CudaSolver {
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
int ldvt, int* dev_lapack_info);
-
- // Batched linear solver using LU factorization from getrfBatched.
- // See:
- http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
- template <typename Scalar>
- Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
- const Scalar* dev_Aarray[], int lda, const int* devIpiv,
- Scalar* dev_Barray[], int ldb, int* info, int batch_size)
- const;
*/
private: