aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/sparse_solver.h
diff options
context:
space:
mode:
authorGravatar Ralf Hannemann-Tamas <ralf.ht@gmail.com>2021-02-08 22:00:31 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-08 22:00:31 +0000
commit984d010b7bcb6a03f0319e79b8a768587be85422 (patch)
treeef0f438ca2690419079c3cc05bc503a85c6b761a /test/sparse_solver.h
parentb578930657c962def63c3b4d0bdd1dde8927f1cd (diff)
add specialization of check_sparse_solving() for SuperLU solver, in order to test adjoint and transpose solves
Diffstat (limited to 'test/sparse_solver.h')
-rw-r--r--test/sparse_solver.h131
1 files changed, 131 insertions, 0 deletions
diff --git a/test/sparse_solver.h b/test/sparse_solver.h
index f45d7ef80..58927944b 100644
--- a/test/sparse_solver.h
+++ b/test/sparse_solver.h
@@ -9,6 +9,7 @@
#include "sparse.h"
#include <Eigen/SparseCore>
+#include <Eigen/SparseLU>
#include <sstream>
template<typename Solver, typename Rhs, typename Guess,typename Result>
@@ -144,6 +145,136 @@ void check_sparse_solving(Solver& solver, const typename Solver::MatrixType& A,
}
}
+// specialization of generic check_sparse_solving for SuperLU in order to also test adjoint and transpose solves
+template<typename Scalar, typename Rhs, typename DenseMat, typename DenseRhs>
+void check_sparse_solving(Eigen::SparseLU<Eigen::SparseMatrix<Scalar> >& solver, const typename Eigen::SparseMatrix<Scalar>& A, const Rhs& b, const DenseMat& dA, const DenseRhs& db)
+{
+ typedef typename Eigen::SparseMatrix<Scalar> Mat;
+ typedef typename Mat::StorageIndex StorageIndex;
+ typedef typename Eigen::SparseLU<Eigen::SparseMatrix<Scalar> > Solver;
+
+ // reference solutions computed by dense QR solver
+ DenseRhs refX1 = dA.householderQr().solve(db); // solution of A x = db
+ DenseRhs refX2 = dA.transpose().householderQr().solve(db); // solution of A^T * x = db (use transposed matrix A^T)
+ DenseRhs refX3 = dA.adjoint().householderQr().solve(db); // solution of A^* * x = db (use adjoint matrix A^*)
+
+
+ {
+ Rhs x1(A.cols(), b.cols());
+ Rhs x2(A.cols(), b.cols());
+ Rhs x3(A.cols(), b.cols());
+ Rhs oldb = b;
+
+ solver.compute(A);
+ if (solver.info() != Success)
+ {
+ std::cerr << "ERROR | sparse solver testing, factorization failed (" << typeid(Solver).name() << ")\n";
+ VERIFY(solver.info() == Success);
+ }
+ x1 = solver.solve(b);
+ if (solver.info() != Success)
+ {
+ std::cerr << "WARNING | sparse solver testing: solving failed (" << typeid(Solver).name() << ")\n";
+ return;
+ }
+ VERIFY(oldb.isApprox(b,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x1.isApprox(refX1,test_precision<Scalar>()));
+
+ // test solve with transposed
+ x2 = solver.transpose().solve(b);
+ VERIFY(oldb.isApprox(b) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x2.isApprox(refX2,test_precision<Scalar>()));
+
+
+ // test solve with adjoint
+ //solver.template _solve_impl_transposed<true>(b, x3);
+ x3 = solver.adjoint().solve(b);
+ VERIFY(oldb.isApprox(b,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x3.isApprox(refX3,test_precision<Scalar>()));
+
+ x1.setZero();
+ solve_with_guess(solver, b, x1, x1);
+ VERIFY(solver.info() == Success && "solving failed when using analyzePattern/factorize API");
+ VERIFY(oldb.isApprox(b,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x1.isApprox(refX1,test_precision<Scalar>()));
+
+ x1.setZero();
+ x2.setZero();
+ x3.setZero();
+ // test the analyze/factorize API
+ solver.analyzePattern(A);
+ solver.factorize(A);
+ VERIFY(solver.info() == Success && "factorization failed when using analyzePattern/factorize API");
+ x1 = solver.solve(b);
+ x2 = solver.transpose().solve(b);
+ x3 = solver.adjoint().solve(b);
+
+ VERIFY(solver.info() == Success && "solving failed when using analyzePattern/factorize API");
+ VERIFY(oldb.isApprox(b,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x1.isApprox(refX1,test_precision<Scalar>()));
+ VERIFY(x2.isApprox(refX2,test_precision<Scalar>()));
+ VERIFY(x3.isApprox(refX3,test_precision<Scalar>()));
+
+ x1.setZero();
+ // test with Map
+ MappedSparseMatrix<Scalar,Mat::Options,StorageIndex> Am(A.rows(), A.cols(), A.nonZeros(), const_cast<StorageIndex*>(A.outerIndexPtr()), const_cast<StorageIndex*>(A.innerIndexPtr()), const_cast<Scalar*>(A.valuePtr()));
+ solver.compute(Am);
+ VERIFY(solver.info() == Success && "factorization failed when using Map");
+ DenseRhs dx(refX1);
+ dx.setZero();
+ Map<DenseRhs> xm(dx.data(), dx.rows(), dx.cols());
+ Map<const DenseRhs> bm(db.data(), db.rows(), db.cols());
+ xm = solver.solve(bm);
+ VERIFY(solver.info() == Success && "solving failed when using Map");
+ VERIFY(oldb.isApprox(bm,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(xm.isApprox(refX1,test_precision<Scalar>()));
+ }
+
+ // if not too large, do some extra check:
+ if(A.rows()<2000)
+ {
+ // test initialization ctor
+ {
+ Rhs x(b.rows(), b.cols());
+ Solver solver2(A);
+ VERIFY(solver2.info() == Success);
+ x = solver2.solve(b);
+ VERIFY(x.isApprox(refX1,test_precision<Scalar>()));
+ }
+
+ // test dense Block as the result and rhs:
+ {
+ DenseRhs x(refX1.rows(), refX1.cols());
+ DenseRhs oldb(db);
+ x.setZero();
+ x.block(0,0,x.rows(),x.cols()) = solver.solve(db.block(0,0,db.rows(),db.cols()));
+ VERIFY(oldb.isApprox(db,0.0) && "sparse solver testing: the rhs should not be modified!");
+ VERIFY(x.isApprox(refX1,test_precision<Scalar>()));
+ }
+
+ // test uncompressed inputs
+ {
+ Mat A2 = A;
+ A2.reserve((ArrayXf::Random(A.outerSize())+2).template cast<typename Mat::StorageIndex>().eval());
+ solver.compute(A2);
+ Rhs x = solver.solve(b);
+ VERIFY(x.isApprox(refX1,test_precision<Scalar>()));
+ }
+
+ // test expression as input
+ {
+ solver.compute(0.5*(A+A));
+ Rhs x = solver.solve(b);
+ VERIFY(x.isApprox(refX1,test_precision<Scalar>()));
+
+ Solver solver2(0.5*(A+A));
+ Rhs x2 = solver2.solve(b);
+ VERIFY(x2.isApprox(refX1,test_precision<Scalar>()));
+ }
+ }
+}
+
+
template<typename Solver, typename Rhs>
void check_sparse_solving_real_cases(Solver& solver, const typename Solver::MatrixType& A, const Rhs& b, const typename Solver::MatrixType& fullA, const Rhs& refX)
{