aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-09-29 13:37:49 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-09-29 13:37:49 +0200
commit842e31cf5c8fd31f394156ada84a1aeeab89ef7e (patch)
treea21201247a051d28816713d31d92cabb53d14697
parentabd3502e9ea3e659c39dd5edc17d6deabd26e048 (diff)
Let KroneckerProduct exploits the recently introduced generic InnerIterator class.
-rw-r--r--unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h33
-rw-r--r--unsupported/test/kronecker_product.cpp12
2 files changed, 22 insertions, 23 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
index 608c72021..b459360df 100644
--- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
+++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
@@ -157,40 +157,27 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
dst.resizeNonZeros(0);
// 1 - evaluate the operands if needed:
- typedef typename internal::nested_eval<Lhs,10>::type Lhs1;
+ typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
const Lhs1 lhs1(m_A);
- typedef typename internal::nested_eval<Rhs,10>::type Rhs1;
+ typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
const Rhs1 rhs1(m_B);
-
- // 2 - construct a SparseView for dense operands
- typedef typename internal::conditional<internal::is_same<typename internal::traits<Lhs1Cleaned>::StorageKind,Sparse>::value, Lhs1, SparseView<const Lhs1Cleaned> >::type Lhs2;
- typedef typename internal::remove_all<Lhs2>::type Lhs2Cleaned;
- const Lhs2 lhs2(lhs1);
- typedef typename internal::conditional<internal::is_same<typename internal::traits<Rhs1Cleaned>::StorageKind,Sparse>::value, Rhs1, SparseView<const Rhs1Cleaned> >::type Rhs2;
- typedef typename internal::remove_all<Rhs2>::type Rhs2Cleaned;
- const Rhs2 rhs2(rhs1);
-
- // 3 - construct respective evaluators
- typedef typename internal::evaluator<Lhs2Cleaned>::type LhsEval;
- LhsEval lhsEval(lhs2);
- typedef typename internal::evaluator<Rhs2Cleaned>::type RhsEval;
- RhsEval rhsEval(rhs2);
-
- typedef typename LhsEval::InnerIterator LhsInnerIterator;
- typedef typename RhsEval::InnerIterator RhsInnerIterator;
+
+ // 2 - construct respective iterators
+ typedef InnerIterator<Lhs1Cleaned> LhsInnerIterator;
+ typedef InnerIterator<Rhs1Cleaned> RhsInnerIterator;
// compute number of non-zeros per innervectors of dst
{
VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA)
- for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
+ for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
- for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
+ for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
@@ -201,9 +188,9 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
{
for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB)
{
- for (LhsInnerIterator itA(lhsEval,kA); itA; ++itA)
+ for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
{
- for (RhsInnerIterator itB(rhsEval,kB); itB; ++itB)
+ for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
{
const DestIndex
i = DestIndex(itA.row() * Br + itB.row()),
diff --git a/unsupported/test/kronecker_product.cpp b/unsupported/test/kronecker_product.cpp
index 753a2d417..02411a262 100644
--- a/unsupported/test/kronecker_product.cpp
+++ b/unsupported/test/kronecker_product.cpp
@@ -216,5 +216,17 @@ void test_kronecker_product()
sC2 = kroneckerProduct(sA,sB);
dC = kroneckerProduct(dA,dB);
VERIFY_IS_APPROX(MatrixXf(sC2),dC);
+
+ sC2 = kroneckerProduct(dA,sB);
+ dC = kroneckerProduct(dA,dB);
+ VERIFY_IS_APPROX(MatrixXf(sC2),dC);
+
+ sC2 = kroneckerProduct(sA,dB);
+ dC = kroneckerProduct(dA,dB);
+ VERIFY_IS_APPROX(MatrixXf(sC2),dC);
+
+ sC2 = kroneckerProduct(2*sA,sB);
+ dC = kroneckerProduct(2*dA,dB);
+ VERIFY_IS_APPROX(MatrixXf(sC2),dC);
}
}