aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-14 14:46:01 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-14 14:46:01 +0100
commit3283d98d1378ccd3a8c89eec1d88108ddd517d95 (patch)
tree42193cf1b6b7fc3c2051df808aeb03d817e393ea /unsupported
parent0d3f496233ceb0e96da0a39e360e5bdd5c89c0e3 (diff)
optimize sparse-sparse Kronecker product
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h17
-rw-r--r--unsupported/test/kronecker_product.cpp34
2 files changed, 50 insertions, 1 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
index a4516056d..b8f2cba17 100644
--- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
+++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
@@ -153,7 +153,22 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
Bc = m_B.cols();
dst.resize(this->rows(), this->cols());
dst.resizeNonZeros(0);
- dst.reserve(m_A.nonZeros() * m_B.nonZeros());
+
+ // compute number of non-zeros per innervectors of dst
+ {
+ VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
+ for (Index kA=0; kA < m_A.outerSize(); ++kA)
+ for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
+ nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
+
+ VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
+ for (Index kB=0; kB < m_B.outerSize(); ++kB)
+ for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
+ nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
+
+ Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
+ dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
+ }
for (Index kA=0; kA < m_A.outerSize(); ++kA)
{
diff --git a/unsupported/test/kronecker_product.cpp b/unsupported/test/kronecker_product.cpp
index c68a07de8..753a2d417 100644
--- a/unsupported/test/kronecker_product.cpp
+++ b/unsupported/test/kronecker_product.cpp
@@ -183,4 +183,38 @@ void test_kronecker_product()
DM_b2.resize(4,8);
DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
CALL_SUBTEST(check_dimension(DM_ab2,10*4,9*8));
+
+ for(int i = 0; i < g_repeat; i++)
+ {
+ double density = Eigen::internal::random<double>(0.01,0.5);
+ int ra = Eigen::internal::random<int>(1,50);
+ int ca = Eigen::internal::random<int>(1,50);
+ int rb = Eigen::internal::random<int>(1,50);
+ int cb = Eigen::internal::random<int>(1,50);
+ SparseMatrix<float,ColMajor> sA(ra,ca), sB(rb,cb), sC;
+ SparseMatrix<float,RowMajor> sC2;
+ MatrixXf dA(ra,ca), dB(rb,cb), dC;
+ initSparse(density, dA, sA);
+ initSparse(density, dB, sB);
+
+ sC = kroneckerProduct(sA,sB);
+ dC = kroneckerProduct(dA,dB);
+ VERIFY_IS_APPROX(MatrixXf(sC),dC);
+
+ sC = kroneckerProduct(sA.transpose(),sB);
+ dC = kroneckerProduct(dA.transpose(),dB);
+ VERIFY_IS_APPROX(MatrixXf(sC),dC);
+
+ sC = kroneckerProduct(sA.transpose(),sB.transpose());
+ dC = kroneckerProduct(dA.transpose(),dB.transpose());
+ VERIFY_IS_APPROX(MatrixXf(sC),dC);
+
+ sC = kroneckerProduct(sA,sB.transpose());
+ dC = kroneckerProduct(dA,dB.transpose());
+ VERIFY_IS_APPROX(MatrixXf(sC),dC);
+
+ sC2 = kroneckerProduct(sA,sB);
+ dC = kroneckerProduct(dA,dB);
+ VERIFY_IS_APPROX(MatrixXf(sC2),dC);
+ }
}