aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-19 09:09:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 09:12:47 -0700
commit34018f8fa7290650291bbd478534e58c128a5df4 (patch)
treed227f22d09089ee4cad773d490b39ecee4300018
parentec962ff63820e3ab9f5cc4c5f37c3579be0afcd9 (diff)
Add GPU support for QR decomposition.
Remove support support for on-the-fly transpose in internal matrix_band_part functor recently added (in anticipation of using it for QR), since it turned out to not be useful. PiperOrigin-RevId: 169249336
-rw-r--r--tensorflow/core/kernels/BUILD6
-rw-r--r--tensorflow/core/kernels/cholesky_op.cc23
-rw-r--r--tensorflow/core/kernels/cuda_solvers.cc152
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h97
-rw-r--r--tensorflow/core/kernels/cuda_solvers_gpu.cu.cc5
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.cc118
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.h2
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc56
-rw-r--r--tensorflow/core/kernels/qr_op_complex128.cc4
-rw-r--r--tensorflow/core/kernels/qr_op_complex64.cc4
-rw-r--r--tensorflow/core/kernels/qr_op_double.cc4
-rw-r--r--tensorflow/core/kernels/qr_op_float.cc4
-rw-r--r--tensorflow/core/kernels/qr_op_impl.h198
-rw-r--r--tensorflow/python/kernel_tests/qr_op_test.py88
14 files changed, 547 insertions, 214 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cff6e30c04..dcbbe5335d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2334,7 +2334,11 @@ tf_kernel_library(
tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
- deps = LINALG_DEPS,
+ deps = LINALG_DEPS + if_cuda([
+ ":cuda_solvers",
+ ":transpose_functor",
+ ":matrix_band_part_op",
+ ]),
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 6668b0d654..3adff530f7 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -76,14 +76,14 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- struct MatrixBandPartFunctor<GPUDevice, T> { \
- void operator()(OpKernelContext* context, const GPUDevice& device, \
- int num_upper_diags, int num_lower_diags, bool transpose, \
- typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- }; \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+ int num_upper_diags, int num_lower_diags, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@@ -132,9 +132,10 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in paralle.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
- functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
- fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
- input_reshaped, output_reshaped);
+ functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
+ band_part(context, context->eigen_device<GPUDevice>(),
+ n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped,
+ output_reshaped);
// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index 43197d8cf4..85f1473c6c 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -174,7 +174,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync(
}
info_checker_callback(status, host_lapack_infos);
};
-
+
auto cb =
std::bind(wrapped_info_checker_callback, context_,
std::move(info_checker_callback), std::move(host_lapack_infos));
@@ -363,6 +363,156 @@ static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,
TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
+template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
+static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
+ OpKernelContext* context,
+ cusolverDnHandle_t cusolver_dn_handle, int m,
+ int n, Scalar* A, int lda, Scalar* tau,
+ int* dev_lapack_info) {
+ /* Get amount of workspace memory required. */
+ int lwork;
+ TF_RETURN_IF_CUSOLVER_ERROR(
+ bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
+ /* Allocate device memory for workspace. */
+ ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
+ /* Launch the solver kernel. */
+ TF_RETURN_IF_CUSOLVER_ERROR(solver(
+ cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
+ CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
+ return Status::OK();
+}
+
+#define GEQRF_INSTANCE(Scalar, lapack_prefix) \
+ template <> \
+ Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \
+ Scalar* tau, int* dev_lapack_info) const { \
+ return GeqrfImpl(DN_BUFSIZE_FN(geqrf, lapack_prefix), \
+ DN_SOLVER_FN(geqrf, lapack_prefix), context_, \
+ cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
+ }
+
+TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);
+
+template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
+static inline Status OrmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
+ OpKernelContext* context,
+ cusolverDnHandle_t cusolver_dn_handle,
+ cublasSideMode_t side, cublasOperation_t trans,
+ int m, int n, int k, const Scalar* dev_a,
+ int lda, const Scalar* dev_tau, Scalar* dev_c,
+ int ldc, int* dev_lapack_info) {
+ /* Get amount of workspace memory required. */
+ int lwork;
+ TF_RETURN_IF_CUSOLVER_ERROR(
+ bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
+ CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
+ /* Allocate device memory for workspace. */
+ ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
+ /* Launch the solver kernel. */
+ TF_RETURN_IF_CUSOLVER_ERROR(solver(
+ cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
+ CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
+ CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
+ return Status::OK();
+}
+
+// Unfortunately the LAPACK function name differs for the real and complex case
+// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
+// one separately.
+template <>
+Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
+ int n, int k, const float* dev_a, int lda,
+ const float* dev_tau, float* dev_c, int ldc,
+ int* dev_lapack_info) const {
+ return OrmqrImpl(DN_BUFSIZE_FN(ormqr, S), DN_SOLVER_FN(ormqr, S), context_,
+ cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
+ dev_tau, dev_c, ldc, dev_lapack_info);
+}
+template <>
+Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
+ int n, int k, const double* dev_a, int lda,
+ const double* dev_tau, double* dev_c, int ldc,
+ int* dev_lapack_info) const {
+ return OrmqrImpl(DN_BUFSIZE_FN(ormqr, D), DN_SOLVER_FN(ormqr, D), context_,
+ cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
+ dev_tau, dev_c, ldc, dev_lapack_info);
+}
+template <>
+Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
+ int n, int k, const std::complex<float>* dev_a,
+ int lda, const std::complex<float>* dev_tau,
+ std::complex<float>* dev_c, int ldc,
+ int* dev_lapack_info) const {
+ return OrmqrImpl(DN_BUFSIZE_FN(unmqr, C), DN_SOLVER_FN(unmqr, C), context_,
+ cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
+ dev_tau, dev_c, ldc, dev_lapack_info);
+}
+template <>
+Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
+ int n, int k, const std::complex<double>* dev_a,
+ int lda, const std::complex<double>* dev_tau,
+ std::complex<double>* dev_c, int ldc,
+ int* dev_lapack_info) const {
+ return OrmqrImpl(DN_BUFSIZE_FN(unmqr, Z), DN_SOLVER_FN(unmqr, Z), context_,
+ cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
+ dev_tau, dev_c, ldc, dev_lapack_info);
+}
+
+template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
+static inline Status OrgqrImpl(BufSizeFnT bufsize, SolverFnT solver,
+ OpKernelContext* context,
+ cusolverDnHandle_t cusolver_dn_handle, int m,
+ int n, int k, Scalar* dev_a, int lda,
+ const Scalar* dev_tau, int* dev_lapack_info) {
+ /* Get amount of workspace memory required. */
+ int lwork;
+ TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
+ CUDAComplex(dev_a), lda,
+ CUDAComplex(dev_tau), &lwork));
+ /* Allocate device memory for workspace. */
+ ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
+ /* Launch the solver kernel. */
+ TF_RETURN_IF_CUSOLVER_ERROR(
+ solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
+ CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
+ lwork, dev_lapack_info));
+ return Status::OK();
+}
+
+// Unfortunately the LAPACK function name differs for the real and complex case
+// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
+// one separately.
+template <>
+Status CudaSolver::Orgqr(int m, int n, int k, float* dev_a, int lda,
+ const float* dev_tau, int* dev_lapack_info) const {
+ return OrgqrImpl(DN_BUFSIZE_FN(orgqr, S), DN_SOLVER_FN(orgqr, S), context_,
+ cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
+ dev_lapack_info);
+}
+template <>
+Status CudaSolver::Orgqr(int m, int n, int k, double* dev_a, int lda,
+ const double* dev_tau, int* dev_lapack_info) const {
+ return OrgqrImpl(DN_BUFSIZE_FN(orgqr, D), DN_SOLVER_FN(orgqr, D), context_,
+ cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
+ dev_lapack_info);
+}
+template <>
+Status CudaSolver::Orgqr(int m, int n, int k, std::complex<float>* dev_a,
+ int lda, const std::complex<float>* dev_tau,
+ int* dev_lapack_info) const {
+ return OrgqrImpl(DN_BUFSIZE_FN(ungqr, C), DN_SOLVER_FN(ungqr, C), context_,
+ cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
+ dev_lapack_info);
+}
+template <>
+Status CudaSolver::Orgqr(int m, int n, int k, std::complex<double>* dev_a,
+ int lda, const std::complex<double>* dev_tau,
+ int* dev_lapack_info) const {
+ return OrgqrImpl(DN_BUFSIZE_FN(ungqr, Z), DN_SOLVER_FN(ungqr, Z), context_,
+ cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
+ dev_lapack_info);
+}
+
//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 7cbdc895dd..38873a0dec 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -147,7 +147,7 @@ class CudaSolver {
Status CopyLapackInfoToHostAsync(
const std::vector<DeviceLapackInfo>& dev_lapack_info,
std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
- info_checker_callback) const;
+ info_checker_callback) const TF_MUST_USE_RESULT;
// ====================================================================
// Wrappers for cuSolverDN and cuBlas solvers start here.
@@ -166,28 +166,29 @@ class CudaSolver {
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */
- const Scalar* B, int ldb, Scalar* C, int ldc) const;
+ const Scalar* B, int ldb, Scalar* C,
+ int ldc) const TF_MUST_USE_RESULT;
// Computes the Cholesky factorization A = L * L^T for a single matrix.
// Returns Status::OK() if the kernel was launched successfully. See:
// http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
template <typename Scalar>
Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
- int* dev_lapack_info) const;
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
// LU factorization.
// Computes LU factorization with partial pivoting P * A = L * U.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
template <typename Scalar>
Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
- int* dev_lapack_info) const;
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Uses LU factorization to solve A * X = B.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
template <typename Scalar>
Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
int lda, const int* pivots, Scalar* B, int ldb,
- int* dev_lapack_info) const;
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Computes partially pivoted LU factorizations for a batch of small matrices.
// Returns Status::OK() if the kernel was launched successfully.See:
@@ -195,7 +196,7 @@ class CudaSolver {
template <typename Scalar>
Status GetrfBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
- int batch_size) const;
+ int batch_size) const TF_MUST_USE_RESULT;
// Batched linear solver using LU factorization from getrfBatched.
// See:
@@ -204,7 +205,8 @@ class CudaSolver {
Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
const Scalar* dev_Aarray[], int lda, const int* devIpiv,
const Scalar* dev_Barray[], int ldb,
- DeviceLapackInfo* dev_lapack_info, int batch_size) const;
+ DeviceLapackInfo* dev_lapack_info,
+ int batch_size) const TF_MUST_USE_RESULT;
// Computes matrix inverses for a batch of small matrices. Uses the outputs
// from GetrfBatched. Returns Status::OK() if the kernel was launched
@@ -214,7 +216,8 @@ class CudaSolver {
Status GetriBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const int* dev_pivots,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
- DeviceLapackInfo* dev_lapack_info, int batch_size) const;
+ DeviceLapackInfo* dev_lapack_info,
+ int batch_size) const TF_MUST_USE_RESULT;
// Computes matrix inverses for a batch of small matrices with size n < 32.
// Returns Status::OK() if the kernel was launched successfully. See:
@@ -222,59 +225,58 @@ class CudaSolver {
template <typename Scalar>
Status MatInvBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
- DeviceLapackInfo* dev_lapack_info, int batch_size) const;
-
- /*
- TODO(rmlarsen, volunteers): Implement the kernels below.
- // Uses Cholesky factorization to solve A * X = B.
- // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs
- template <typename Scalar>
- Status Potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* dev_A, int
- lda, Scalar* dev_B, int ldb, int* dev_lapack_info) const;
+ DeviceLapackInfo* dev_lapack_info,
+ int batch_size) const TF_MUST_USE_RESULT;
// QR factorization.
// Computes QR factorization A = Q * R.
+ // Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
template <typename Scalar>
- Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_TAU, int*
- devInfo) const;
+ Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
- // Multiplies by Q.
+ // Overwrite matrix C by product of C and Householder matrix Q. The
+ // Householder matrix Q is represented by the output from Geqrf in dev_a and
+ // dev_tau.
+ // Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
template <typename Scalar>
- Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int
- k, const Scalar* dev_a, int lda, const Scalar* dev_tau, Scalar* dev_c, int
- ldc, int* dev_lapack_info) const;
-
- // Generate Q.
+ Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
+ int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
+ Scalar* dev_c, int ldc,
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
+
+ // Overwrites QR factorization produced by Geqrf by Householder matrix Q.
+ // On input, the Householder matrix Q is represented by the output from Geqrf
+ // in dev_a and dev_tau. On output, dev_a is overwritten with the first n
+ // columns of Q.
+ // Requires m >= n >= 0.
+ // Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
template <typename Scalar>
- Status Orgqr(int m, int n, int k, Scalar* dev_A, int lda, const Scalar*
- dev_tau, int* dev_lapack_info) const;
-
- // Symmetric/Hermitian Eigen decomposition.
- // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
- template <typename Scalar>
- Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
- dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const;
+ Status Orgqr(int m, int n, int k, Scalar* dev_a, int lda,
+ const Scalar* dev_tau,
+ int* dev_lapack_info) const TF_MUST_USE_RESULT;
-*/
// Singular value decomposition.
+ // Returns Status::OK() if the kernel was launched successfully.
+ // TODO(rmlarsen, volunteers): Add support for complex types.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
- int ldvt, int* dev_lapack_info) const;
+ int ldvt, int* dev_lapack_info) const TF_MUST_USE_RESULT;
+
/*
- // Batched linear solver using LU factorization from getrfBatched.
- // See:
- http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
- template <typename Scalar>
- Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
- const Scalar* dev_Aarray[], int lda, const int* devIpiv,
- Scalar* dev_Barray[], int ldb, int* info, int batch_size)
- const;
- */
+ TODO(rmlarsen, volunteers): Implement the kernels below.
+
+ // Symmetric/Hermitian Eigen decomposition.
+ // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
+ template <typename Scalar>
+ Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
+ dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const TF_MUST_USE_RESULT;
+ */
private:
OpKernelContext* context_; // not owned.
@@ -371,7 +373,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct AdjointBatchFunctor {
// We assume that the tensor sizes are correct.
- void operator()(const Device& d,
+ void operator()(const Device& device,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};
@@ -380,7 +382,8 @@ struct AdjointBatchFunctor {
// in a flattened batch.
template <typename Device, typename Scalar>
struct DeterminantFromPivotedLUFunctor {
- void operator()(const Device& d, typename TTypes<Scalar, 3>::Tensor lu_factor,
+ void operator()(const Device& device,
+ typename TTypes<Scalar, 3>::Tensor lu_factor,
const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
int* info);
};
@@ -390,7 +393,7 @@ struct DeterminantFromPivotedLUFunctor {
// op.
template <typename Device, typename Scalar>
struct EyeFunctor {
- void operator()(const Device& d,
+ void operator()(const Device& device,
typename TTypes<Scalar, 3>::Tensor matrix_batch);
};
diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
index af6c094d7a..bbbe1377b2 100644
--- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
+++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
@@ -190,7 +190,6 @@ struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
}
};
-// Instantiate implementations for the 4 numeric types.
template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, std::complex<float>>;
@@ -202,7 +201,6 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
int n, Scalar* matrix_batch_ptr) {
const int matrix_size = m * n;
const Scalar one = Const<Scalar>::make_const(1.0);
- const Scalar zero = Const<Scalar>::make_const(0.0);
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
if (batch >= batch_size) {
break;
@@ -216,7 +214,7 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
if (col >= n) {
break;
}
- matrix_batch_ptr[row_start + col] = row == col ? one : zero;
+ matrix_batch_ptr[row_start + col] = row == col ? one : Scalar();
}
}
}
@@ -239,7 +237,6 @@ struct EyeFunctor<GPUDevice, Scalar> {
}
};
-// Instantiate implementations for the 4 numeric types.
template struct EyeFunctor<GPUDevice, float>;
template struct EyeFunctor<GPUDevice, double>;
template struct EyeFunctor<GPUDevice, std::complex<float>>;
diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc
index 8b8accc0b3..e5f9086dba 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op.cc
@@ -93,7 +93,7 @@ class MatrixBandPartOp : public OpKernel {
auto output_reshaped = output->flat_inner_dims<T, 3>();
functor::MatrixBandPartFunctor<Device, T> fn;
fn(context, context->eigen_device<Device>(), num_lower, num_upper,
- false /* transpose */, input_reshaped, output_reshaped);
+ input_reshaped, output_reshaped);
}
private:
@@ -126,7 +126,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
void operator()(OpKernelContext* context, const CPUDevice& device,
- int num_lower_diags, int num_upper_diags, bool transpose,
+ int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
const int64 b = input.dimension(0);
@@ -137,72 +137,46 @@ struct MatrixBandPartFunctor<CPUDevice, Scalar> {
const int64 total_rows = b * m;
const int64 row_cost = 10 * n;
const bool in_place = input.data() == output.data();
- CHECK(!(transpose && in_place));
- if (!transpose) {
- auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
- if (!in_place) {
- std::fill(output.data() + begin * n, output.data() + end * n,
- Scalar());
- }
- const int64 batch_begin = begin / m;
- const int64 batch_end = (end + m - 1) / m;
- for (int64 batch = batch_begin; batch < batch_end; ++batch) {
- const int64 row_begin = begin > batch * m ? begin % m : 0;
- const int64 row_end = end < (batch + 1) * m ? end % m : m;
- for (int64 row = row_begin; row < row_end; ++row) {
- const int64 band_start =
- num_lower_diags < 0
- ? 0
- : std::min(n, std::max(0ll, row - num_lower_diags));
- const int64 band_end = num_upper_diags < 0
- ? n
- : std::min(static_cast<int64>(n),
- row + num_upper_diags + 1);
- if (in_place) {
- if (band_start > 0) {
- std::fill(&output(batch, row, 0),
- &output(batch, row, band_start), Scalar());
- }
- if (band_end < n) {
- std::fill(&output(batch, row, band_end), &output(batch, row, n),
- Scalar());
- }
- } else {
- if (band_start < band_end) {
- const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
- band_start);
- const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
- 1, 1, band_end - band_start);
- output.slice(indices, sizes) = input.slice(indices, sizes);
- }
+ auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
+ if (!in_place) {
+ std::fill(output.data() + begin * n, output.data() + end * n, Scalar());
+ }
+ const int64 batch_begin = begin / m;
+ const int64 batch_end = (end + m - 1) / m;
+ for (int64 batch = batch_begin; batch < batch_end; ++batch) {
+ const int64 row_begin = begin > batch * m ? begin % m : 0;
+ const int64 row_end = end < (batch + 1) * m ? end % m : m;
+ for (int64 row = row_begin; row < row_end; ++row) {
+ const int64 band_start =
+ num_lower_diags < 0
+ ? 0
+ : std::min(n, std::max(0ll, row - num_lower_diags));
+ const int64 band_end =
+ num_upper_diags < 0
+ ? n
+ : std::min(static_cast<int64>(n), row + num_upper_diags + 1);
+ if (in_place) {
+ if (band_start > 0) {
+ std::fill(&output(batch, row, 0), &output(batch, row, band_start),
+ Scalar());
}
- }
- }
- };
- thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
- } else {
- output.device(device) = output.constant(Scalar());
- auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
- const int64 batch_begin = begin / m;
- const int64 batch_end = (end + m - 1) / m;
- for (int64 batch = batch_begin; batch < batch_end; ++batch) {
- const int64 row_begin = begin > batch * m ? begin % m : 0;
- const int64 row_end = end < (batch + 1) * m ? end % m : m;
- for (int64 row = row_begin; row < row_end; ++row) {
- const int64 band_start =
- num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
- const int64 band_end = num_upper_diags < 0
- ? n
- : std::min(static_cast<int64>(n),
- row + num_upper_diags + 1);
- for (int64 col = band_start; col < band_end; ++col) {
- output(batch, col, row) = input(batch, row, col);
+ if (band_end < n) {
+ std::fill(&output(batch, row, band_end), &output(batch, row, n),
+ Scalar());
+ }
+ } else {
+ if (band_start < band_end) {
+ const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
+ band_start);
+ const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
+ 1, 1, band_end - band_start);
+ output.slice(indices, sizes) = input.slice(indices, sizes);
}
}
}
- };
- thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
- }
+ }
+ };
+ thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
};
@@ -216,14 +190,14 @@ TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- struct MatrixBandPartFunctor<GPUDevice, T> { \
- void operator()(OpKernelContext* context, const GPUDevice& device, \
- int num_upper_diags, int num_lower_diags, bool transpose, \
- typename TTypes<T, 3>::ConstTensor input, \
- typename TTypes<T, 3>::Tensor output); \
- }; \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ struct MatrixBandPartFunctor<GPUDevice, T> { \
+ void operator()(OpKernelContext* context, const GPUDevice& device, \
+ int num_upper_diags, int num_lower_diags, \
+ typename TTypes<T, 3>::ConstTensor input, \
+ typename TTypes<T, 3>::Tensor output); \
+ }; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index 43b6724dae..97cc950793 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct MatrixBandPartFunctor {
void operator()(OpKernelContext* context, const Device& device,
- int num_upper_diags, int num_lower_diags, bool transpose,
+ int num_upper_diags, int num_lower_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};
diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
index afebdacdca..41b2f5c0ef 100644
--- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc
@@ -28,41 +28,22 @@ namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-template <bool transpose, typename Scalar>
+template <typename Scalar>
__global__ void MatrixBandPartKernel(const int num_threads,
const int batch_size, const int m,
const int n, const int num_lower_diags,
const int num_upper_diags,
const Scalar* input_ptr,
Scalar* output_ptr) {
- if (!transpose) {
- CUDA_1D_KERNEL_LOOP(index, num_threads) {
- const int col = index % n;
- const int row = (index / n) % m;
- const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
- const int band_end =
- (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
- if (col < band_start || col >= band_end) {
- output_ptr[index] = Scalar();
- } else {
- output_ptr[index] = input_ptr[index];
- }
- }
- } else {
- const int matrix_size = m * n;
- CUDA_1D_KERNEL_LOOP(index, num_threads) {
- const int col = index % n;
- const int row = (index / n) % m;
- const int batch = index / matrix_size;
- const int transpose_index = batch * matrix_size + n * col + row;
- const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
- const int band_end =
- (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
- if (col < band_start || col >= band_end) {
- output_ptr[transpose_index] = Scalar();
- } else {
- output_ptr[transpose_index] = input_ptr[index];
- }
+ CUDA_1D_KERNEL_LOOP(index, num_threads) {
+ const int col = index % n;
+ const int row = (index / n) % m;
+ const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
+ const int band_end = (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
+ if (col < band_start || col >= band_end) {
+ output_ptr[index] = Scalar();
+ } else {
+ output_ptr[index] = input_ptr[index];
}
}
}
@@ -70,7 +51,7 @@ __global__ void MatrixBandPartKernel(const int num_threads,
template <typename Scalar>
struct MatrixBandPartFunctor<GPUDevice, Scalar> {
void operator()(OpKernelContext* context, const GPUDevice& device,
- int num_lower_diags, int num_upper_diags, bool transpose,
+ int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
using CudaType = typename CUDAComplexT<Scalar>::type;
@@ -80,17 +61,10 @@ struct MatrixBandPartFunctor<GPUDevice, Scalar> {
const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
- if (transpose) {
- MatrixBandPartKernel<true>
- <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
- config.virtual_thread_count, batch_size, m, n, num_lower_diags,
- num_upper_diags, input_ptr, output_ptr);
- } else {
- MatrixBandPartKernel<false>
- <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
- config.virtual_thread_count, batch_size, m, n, num_lower_diags,
- num_upper_diags, input_ptr, output_ptr);
- }
+ MatrixBandPartKernel<<<config.block_count, config.thread_per_block, 0,
+ device.stream()>>>(
+ config.virtual_thread_count, batch_size, m, n, num_lower_diags,
+ num_upper_diags, input_ptr, output_ptr);
}
};
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index f22bdf0d21..c5b73139bb 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
+#if GOOGLE_CUDA
+REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_complex64.cc b/tensorflow/core/kernels/qr_op_complex64.cc
index 2d99a856a3..4e14f2639c 100644
--- a/tensorflow/core/kernels/qr_op_complex64.cc
+++ b/tensorflow/core/kernels/qr_op_complex64.cc
@@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
+#if GOOGLE_CUDA
+REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 3873d7fbcf..51885eb355 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
+#if GOOGLE_CUDA
+REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index e23cd5a0d9..d0a1dd4204 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
+#if GOOGLE_CUDA
+REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index 029ef83480..aea0c552de 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -19,10 +19,16 @@ limitations under the License.
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
+#include <numeric>
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,6 +36,13 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#if GOOGLE_CUDA
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/kernels/matrix_band_part_op.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
+#endif
+
namespace tensorflow {
template <class Scalar>
@@ -107,4 +120,189 @@ class QrOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};
+#if GOOGLE_CUDA
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <class Scalar>
+class QrOpGpu : public AsyncOpKernel {
+ public:
+ explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
+ }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
+ const Tensor& input = context->input(0);
+ const int ndims = input.dims();
+ const int64 m = input.dim_size(ndims - 2);
+ const int64 n = input.dim_size(ndims - 1);
+ const int64 min_size = std::min(m, n);
+ const int64 batch_size =
+ input.template flat_inner_dims<Scalar, 3>().dimension(0);
+
+ // Validate inputs.
+ OP_REQUIRES_ASYNC(
+ context, ndims >= 2,
+ errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
+ done);
+
+ // Allocate output.
+ // If full_matrices_ is true then Q is m x m and R is m x n.
+ // Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
+ Tensor* q;
+ TensorShape q_shape = input.shape();
+ q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
+ done);
+ Tensor* r;
+ TensorShape r_shape = input.shape();
+ r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
+ done);
+
+ if (input.NumElements() == 0) {
+ done();
+ return;
+ }
+
+ // Allocate temporaries.
+ Tensor input_transposed;
+ TensorShape transposed_shape = input.shape();
+ transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
+ transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<Scalar>::value, transposed_shape,
+ &input_transposed),
+ done);
+
+ Tensor tau;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<Scalar>::value,
+ TensorShape({batch_size, min_size}), &tau),
+ done);
+
+ // Transpose input, since cuSolver uses column-major, while TensorFlow uses
+ // row-major storage.
+ std::vector<int> perm(ndims);
+ std::iota(perm.begin(), perm.end(), 0);
+ std::swap(perm[ndims - 2], perm[ndims - 1]);
+ const GPUDevice& device = context->eigen_device<GPUDevice>();
+ OP_REQUIRES_OK_ASYNC(
+ context, DoTranspose(device, input, perm, &input_transposed), done);
+
+ // Compute QR decomposition in-place in input_transposed.
+ CudaSolver solver(context);
+ std::vector<DeviceLapackInfo> dev_info;
+ dev_info.emplace_back(context, batch_size, "geqrf");
+ auto input_transposed_reshaped =
+ input_transposed.flat_inner_dims<Scalar, 3>();
+ auto tau_matrix = tau.matrix<Scalar>();
+ auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
+ for (int batch = 0; batch < batch_size; ++batch) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ solver.Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
+ &tau_matrix(batch, 0),
+ dev_info.back().mutable_data() + batch),
+ done);
+ }
+
+ // Generate R. R is equal to the upper triangle of the decomposition
+ // stored in input_transposed. Crop, transpose (to get back to row-major)
+ // and copy it to the output buffer.
+ if (full_matrices_ || m == n) {
+ OP_REQUIRES_OK_ASYNC(
+ context, DoTranspose(device, input_transposed, perm, r), done);
+ } else {
+ const Scalar alpha(1);
+ const Scalar beta(0);
+ const Scalar* dummy = nullptr;
+ for (int batch = 0; batch < batch_size; ++batch) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ solver.Geam(CUBLAS_OP_T, CUBLAS_OP_N, n,
+ full_matrices_ ? m : min_size, &alpha,
+ &input_transposed_reshaped(batch, 0, 0), m, &beta,
+ dummy, n, &r_reshaped(batch, 0, 0), n),
+ done);
+ }
+ }
+ // Extract the upper triangle of r (i.e. zero out the strictly lower
+ // triangle).
+ functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
+ auto r_reshaped_const =
+ const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
+ band_part(context, device, 0 /* num_lower_diags */,
+ -1 /* num_upper_diags */, r_reshaped_const, r_reshaped);
+
+ // Generate Q from the decomposition in input_transposed.
+ if (m != n && (full_matrices_ || m < n)) {
+ context->CtxFailure(
+ errors::Unimplemented("The case m != n && (full_matrices_ || m < "
+ "n) is not currently supported on GPU."));
+ done();
+ return;
+
+ /* TODO(rmlarsen): FIXME. This branch fails with info != 0 (both
+ positive and negative) error statuses from ORMQR.
+
+ // Generate full m x m matrix Q by computing the product Q^T * I
+ // (transpose to get back to row-major form).
+ functor::EyeFunctor<GPUDevice, Scalar> eye;
+ auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
+ eye(device, q_reshaped);
+ dev_info.emplace_back(context, batch_size, "ormqr");
+ for (int batch = 0; batch < batch_size; ++batch) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ solver.Ormqr(CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, m, min_size,
+ &input_transposed_reshaped(batch, 0, 0), m,
+ &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
+ dev_info.back().mutable_data() + batch),
+ done);
+ }
+ */
+ } else {
+ // Generate m x n matrix Q. In this case we can use the more efficient
+ // algorithm in Orgqr to generate Q in place.
+ dev_info.emplace_back(context, batch_size, "orgqr");
+ for (int batch = 0; batch < batch_size; ++batch) {
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ solver.Orgqr(
+ m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
+ &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
+ done);
+ }
+ OP_REQUIRES_OK_ASYNC(
+ context, DoTranspose(device, input_transposed, perm, q), done);
+ }
+
+ // Asynchronously check return status from cuSolver kernels.
+ TensorReference input_transposed_ref(input_transposed);
+ TensorReference tau_ref(tau);
+ auto info_checker = [context, dev_info, input_transposed_ref, tau_ref,
+ done](const Status& status,
+ const std::vector<HostLapackInfo>& host_infos) {
+ input_transposed_ref.Unref();
+ tau_ref.Unref();
+ OP_REQUIRES_OK_ASYNC(context, status, done);
+ done();
+ };
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ solver.CopyLapackInfoToHostAsync(dev_info, std::move(info_checker)),
+ done);
+ }
+
+ private:
+ bool full_matrices_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
+};
+
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 7867e0e42d..6b5becef60 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -27,6 +27,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+def _AddTest(test_class, op_name, testcase_name, fn):
+ test_name = "_".join(["test", op_name, testcase_name])
+ if hasattr(test_class, test_name):
+ raise RuntimeError("Test %s defined more than once" % test_name)
+ setattr(test_class, test_name, fn)
+
+
class QrOpTest(test.TestCase):
def testWrongDimensions(self):
@@ -41,7 +48,7 @@ class QrOpTest(test.TestCase):
linalg_ops.qr(vector)
-def _GetQrOpTest(dtype_, shape_, use_static_shape_):
+def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
@@ -95,36 +102,41 @@ def _GetQrOpTest(dtype_, shape_, use_static_shape_):
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
- for full_matrices in False, True:
- with self.test_session() as sess:
- if use_static_shape_:
- x_tf = constant_op.constant(x_np)
- else:
- x_tf = array_ops.placeholder(dtype_)
- q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
-
- if use_static_shape_:
- q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
- else:
- q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
-
- q_dims = q_tf_val.shape
- np_q = np.ndarray(q_dims, dtype_)
- np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
- new_first_dim = np_q_reshape.shape[0]
-
- x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
- for i in range(new_first_dim):
- if full_matrices:
- np_q_reshape[i,:,:], _ = \
+ # TODO(rmlarsen): Debug failure due to invalid parameter to ORMQR.
+ rows_ = shape_[-2]
+ cols_ = shape_[-1]
+ use_gpu = False if rows_ < cols_ or (full_matrices_ and
+ rows_ != cols_) else True
+
+ with self.test_session(use_gpu=use_gpu) as sess:
+ if use_static_shape_:
+ x_tf = constant_op.constant(x_np)
+ else:
+ x_tf = array_ops.placeholder(dtype_)
+ q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)
+
+ if use_static_shape_:
+ q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
+ else:
+ q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
+
+ q_dims = q_tf_val.shape
+ np_q = np.ndarray(q_dims, dtype_)
+ np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
+ new_first_dim = np_q_reshape.shape[0]
+
+ x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
+ for i in range(new_first_dim):
+ if full_matrices_:
+ np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="complete")
- else:
- np_q_reshape[i,:,:], _ = \
+ else:
+ np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="reduced")
- np_q = np.reshape(np_q_reshape, q_dims)
- CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
- CheckApproximation(self, x_np, q_tf_val, r_tf_val)
- CheckUnitary(self, q_tf_val)
+ np_q = np.reshape(np_q_reshape, q_dims)
+ CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
+ CheckApproximation(self, x_np, q_tf_val, r_tf_val)
+ CheckUnitary(self, q_tf_val)
return Test
@@ -133,11 +145,15 @@ if __name__ == "__main__":
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 1, 2, 5, 10, 32, 100:
for cols in 1, 2, 5, 10, 32, 100:
- for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
- shape = batch_dims + (rows, cols)
- for use_static_shape in True, False:
- name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
- use_static_shape)
- setattr(QrOpTest, "testQr_" + name,
- _GetQrOpTest(dtype, shape, use_static_shape))
+ for full_matrices in False, True:
+ for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
+ for use_static_shape in True, False:
+ shape = batch_dims + (rows, cols)
+ name = "%s_%s_full_%s_static_%s" % (dtype.__name__,
+ "_".join(map(str, shape)),
+ full_matrices,
+ use_static_shape)
+ _AddTest(QrOpTest, "Qr", name,
+ _GetQrOpTest(dtype, shape, full_matrices,
+ use_static_shape))
test.main()