aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Chen-Pang He <jdh8@ms63.hinet.net>2013-07-15 09:10:17 +0800
committerGravatar Chen-Pang He <jdh8@ms63.hinet.net>2013-07-15 09:10:17 +0800
commit4b780553e05e35cb4813cd57eaf12befb6062891 (patch)
treef74ea7f4aa6e9de820bb0778a50f2db491f814a9 /unsupported
parent9be658f7015161989b7ecccd70fd050ce563cad9 (diff)
Eliminate unnecessary copying for sparse Kronecker product.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h104
-rw-r--r--unsupported/test/kronecker_product.cpp21
2 files changed, 74 insertions, 51 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
index 532896c3b..6ec8eb558 100644
--- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
+++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
@@ -14,35 +14,23 @@
namespace Eigen {
-template<typename Scalar, int Options, typename Index> class SparseMatrix;
-
-/*!
- * \brief Kronecker tensor product helper class for dense matrices
- *
- * This class is the return value of kroneckerProduct(MatrixBase,
- * MatrixBase). Use the function rather than construct this class
- * directly to avoid specifying template prarameters.
- *
- * \tparam Lhs Type of the left-hand side, a matrix expression.
- * \tparam Rhs Type of the rignt-hand side, a matrix expression.
- */
-template<typename Lhs, typename Rhs>
-class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
+template<typename Derived>
+class KroneckerProductBase : public ReturnByValue<Derived>
{
private:
- typedef ReturnByValue<KroneckerProduct> Base;
- typedef typename Base::Scalar Scalar;
- typedef typename Base::Index Index;
+ typedef typename internal::traits<Derived> Traits;
+ typedef typename Traits::Lhs Lhs;
+ typedef typename Traits::Rhs Rhs;
+ typedef typename Traits::Scalar Scalar;
+
+ protected:
+ typedef typename Traits::Index Index;
public:
- /*! \brief Constructor. */
- KroneckerProduct(const Lhs& A, const Rhs& B)
+ KroneckerProductBase(const Lhs& A, const Rhs& B)
: m_A(A), m_B(B)
{}
- /*! \brief Evaluate the Kronecker tensor product. */
- template<typename Dest> void evalTo(Dest& dst) const;
-
inline Index rows() const { return m_A.rows() * m_B.rows(); }
inline Index cols() const { return m_A.cols() * m_B.cols(); }
@@ -54,16 +42,44 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
Scalar coeff(Index i) const
{
- EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
}
- private:
+ protected:
typename Lhs::Nested m_A;
typename Rhs::Nested m_B;
};
/*!
+ * \brief Kronecker tensor product helper class for dense matrices
+ *
+ * This class is the return value of kroneckerProduct(MatrixBase,
+ * MatrixBase). Use the function rather than construct this class
+ * directly to avoid specifying template prarameters.
+ *
+ * \tparam Lhs Type of the left-hand side, a matrix expression.
+ * \tparam Rhs Type of the rignt-hand side, a matrix expression.
+ */
+template<typename Lhs, typename Rhs>
+class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
+{
+ private:
+ typedef KroneckerProductBase<KroneckerProduct> Base;
+ using Base::m_A;
+ using Base::m_B;
+
+ public:
+ /*! \brief Constructor. */
+ KroneckerProduct(const Lhs& A, const Rhs& B)
+ : Base(A, B)
+ {}
+
+ /*! \brief Evaluate the Kronecker tensor product. */
+ template<typename Dest> void evalTo(Dest& dst) const;
+};
+
+/*!
* \brief Kronecker tensor product helper class for sparse matrices
*
* If at least one of the operands is a sparse matrix expression,
@@ -77,40 +93,28 @@ class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
*/
template<typename Lhs, typename Rhs>
-class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
+class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
{
private:
- typedef typename internal::traits<KroneckerProductSparse>::Index Index;
+ typedef KroneckerProductBase<KroneckerProductSparse> Base;
+ using Base::m_A;
+ using Base::m_B;
public:
/*! \brief Constructor. */
KroneckerProductSparse(const Lhs& A, const Rhs& B)
- : m_A(A), m_B(B)
+ : Base(A, B)
{}
/*! \brief Evaluate the Kronecker tensor product. */
template<typename Dest> void evalTo(Dest& dst) const;
-
- inline Index rows() const { return m_A.rows() * m_B.rows(); }
- inline Index cols() const { return m_A.cols() * m_B.cols(); }
-
- template<typename Scalar, int Options, typename Index>
- operator SparseMatrix<Scalar, Options, Index>()
- {
- SparseMatrix<Scalar, Options, Index> result;
- evalTo(result.derived());
- return result;
- }
-
- private:
- typename Lhs::Nested m_A;
- typename Rhs::Nested m_B;
};
template<typename Lhs, typename Rhs>
template<typename Dest>
void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
{
+ typedef typename Base::Index Index;
const int BlockRows = Rhs::RowsAtCompileTime,
BlockCols = Rhs::ColsAtCompileTime;
const Index Br = m_B.rows(),
@@ -124,9 +128,10 @@ template<typename Lhs, typename Rhs>
template<typename Dest>
void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
{
+ typedef typename Base::Index Index;
const Index Br = m_B.rows(),
Bc = m_B.cols();
- dst.resize(rows(),cols());
+ dst.resize(this->rows(), this->cols());
dst.resizeNonZeros(0);
dst.reserve(m_A.nonZeros() * m_B.nonZeros());
@@ -155,6 +160,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> >
typedef typename remove_all<_Lhs>::type Lhs;
typedef typename remove_all<_Rhs>::type Rhs;
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
+ typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
enum {
Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
@@ -193,6 +199,8 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
| EvalBeforeNestingBit | EvalBeforeAssigningBit,
CoeffReadCost = Dynamic
};
+
+ typedef SparseMatrix<Scalar> ReturnType;
};
} // end namespace internal
@@ -228,6 +236,16 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
* Computes Kronecker tensor product of two matrices, at least one of
* which is sparse
*
+ * \warning If you want to replace a matrix by its Kronecker product
+ * with some matrix, do \b NOT do this:
+ * \code
+ * A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
+ * \endcode
+ * instead, use eval() to work around this:
+ * \code
+ * A = kroneckerProduct(A,B).eval();
+ * \endcode
+ *
* \param a Dense/sparse matrix a
* \param b Dense/sparse matrix b
* \return Kronecker tensor product of a and b, stored in a sparse
diff --git a/unsupported/test/kronecker_product.cpp b/unsupported/test/kronecker_product.cpp
index 8ddc6ec28..c68a07de8 100644
--- a/unsupported/test/kronecker_product.cpp
+++ b/unsupported/test/kronecker_product.cpp
@@ -107,31 +107,34 @@ void test_kronecker_product()
SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
- // test kroneckerProduct(DM_block,DM,DM_fixedSize)
+ // test DM_fixedSize = kroneckerProduct(DM_block,DM)
Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
+ CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
for(int i=0;i<DM_fix_ab.rows();++i)
for(int j=0;j<DM_fix_ab.cols();++j)
VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
- // test kroneckerProduct(DM,DM,DM_block)
+ // test DM_block = kroneckerProduct(DM,DM)
MatrixXd DM_block_ab(10,15);
DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
- // test kroneckerProduct(DM,DM,DM)
+ // test DM = kroneckerProduct(DM,DM)
MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(DM_ab));
+ CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
- // test kroneckerProduct(SM,DM,SM)
+ // test SM = kroneckerProduct(SM,DM)
SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab));
SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
+ CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
- // test kroneckerProduct(DM,SM,SM)
+ // test SM = kroneckerProduct(DM,SM)
SM_ab.setZero();
SM_ab.insert(0,0)=37.0;
SM_ab = kroneckerProduct(DM_a,SM_b);
@@ -140,8 +143,9 @@ void test_kronecker_product()
SM_ab2.insert(0,0)=37.0;
SM_ab2 = kroneckerProduct(DM_a,SM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
+ CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
- // test kroneckerProduct(SM,SM,SM)
+ // test SM = kroneckerProduct(SM,SM)
SM_ab.resize(2,33);
SM_ab.insert(0,0)=37.0;
SM_ab = kroneckerProduct(SM_a,SM_b);
@@ -150,8 +154,9 @@ void test_kronecker_product()
SM_ab2.insert(0,0)=37.0;
SM_ab2 = kroneckerProduct(SM_a,SM_b);
CALL_SUBTEST(check_kronecker_product(SM_ab2));
+ CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
- // test kroneckerProduct(SM,SM,SM) with sparse pattern
+ // test SM = kroneckerProduct(SM,SM) with sparse pattern
SM_a.resize(4,5);
SM_b.resize(3,2);
SM_a.resizeNonZeros(0);
@@ -169,7 +174,7 @@ void test_kronecker_product()
SM_ab = kroneckerProduct(SM_a,SM_b);
CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
- // test dimension of result of kroneckerProduct(DM,DM,DM)
+ // test dimension of result of DM = kroneckerProduct(DM,DM)
MatrixXd DM_a2(2,1);
MatrixXd DM_b2(5,4);
MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);