aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/src/KroneckerProduct
diff options
context:
space:
mode:
authorGravatar Chen-Pang He <jdh8@ms63.hinet.net>2012-10-15 19:45:50 +0800
committerGravatar Chen-Pang He <jdh8@ms63.hinet.net>2012-10-15 19:45:50 +0800
commit0508a0620b51a9faaffea0c520b5c1840dd32d29 (patch)
treef7d247cd8d6d68b915c8d2f63c2809a0d5686863 /unsupported/Eigen/src/KroneckerProduct
parent8284e7134b59bb8d4307d1207cc4bea5c68d5674 (diff)
Let KroneckerProduct inherit ReturnByValue to eliminate temporary evaluation. It's uncommon to store the product back to one of the operands.
Diffstat (limited to 'unsupported/Eigen/src/KroneckerProduct')
-rw-r--r--unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h141
1 files changed, 73 insertions, 68 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
index c33d8f0ce..5149566a9 100644
--- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
+++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h
@@ -18,59 +18,6 @@
namespace Eigen {
-namespace internal {
-
-template<typename _Lhs, typename _Rhs>
-struct traits<KroneckerProduct<_Lhs,_Rhs> >
-{
- typedef MatrixXpr XprKind;
- typedef typename remove_all<_Lhs>::type Lhs;
- typedef typename remove_all<_Rhs>::type Rhs;
- typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
- typedef Dense StorageKind;
- typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
-
- enum {
- RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
- ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
- MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
- MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
- Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0)
- | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit,
- CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
- };
-};
-
-template<typename _Lhs, typename _Rhs>
-struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
-{
- typedef MatrixXpr XprKind;
- typedef typename remove_all<_Lhs>::type Lhs;
- typedef typename remove_all<_Rhs>::type Rhs;
- typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
- typedef Sparse StorageKind;
- typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
-
- enum {
- LhsFlags = Lhs::Flags,
- RhsFlags = Rhs::Flags,
-
- RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
- ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
- MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
- MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
-
- EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
- RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
-
- Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
- | EvalBeforeNestingBit | EvalBeforeAssigningBit,
- CoeffReadCost = Dynamic
- };
-};
-
-} // end namespace internal
-
/*!
* \brief Kronecker tensor product helper class for dense matrices
*
@@ -82,12 +29,14 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
* \tparam Rhs Type of the rignt-hand side, a matrix expression.
*/
template<typename Lhs, typename Rhs>
-class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> >
+class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
{
- public:
- typedef MatrixBase<KroneckerProduct> Base;
- EIGEN_DENSE_PUBLIC_INTERFACE(KroneckerProduct)
+ private:
+ typedef ReturnByValue<KroneckerProduct> Base;
+ typedef typename Base::Scalar Scalar;
+ typedef typename Base::Index Index;
+ public:
/*! \brief Constructor. */
KroneckerProduct(const Lhs& A, const Rhs& B)
: m_A(A), m_B(B)
@@ -99,13 +48,13 @@ class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> >
inline Index rows() const { return m_A.rows() * m_B.rows(); }
inline Index cols() const { return m_A.cols() * m_B.cols(); }
- typename Base::CoeffReturnType coeff(Index row, Index col) const
+ Scalar coeff(Index row, Index col) const
{
return m_A.coeff(row / m_A.cols(), col / m_A.rows()) *
m_B.coeff(row % m_A.cols(), col % m_A.rows());
}
- typename Base::CoeffReturnType coeff(Index i) const
+ Scalar coeff(Index i) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
@@ -198,9 +147,71 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
}
}
+namespace internal {
+
+template<typename _Lhs, typename _Rhs>
+struct traits<KroneckerProduct<_Lhs,_Rhs> >
+{
+ typedef typename remove_all<_Lhs>::type Lhs;
+ typedef typename remove_all<_Rhs>::type Rhs;
+ typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
+
+ enum {
+ Rows = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
+ Cols = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
+ MaxRows = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
+ MaxCols = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
+ CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
+ };
+
+ typedef Matrix<Scalar,Rows,Cols> ReturnType;
+};
+
+template<typename _Lhs, typename _Rhs>
+struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
+{
+ typedef MatrixXpr XprKind;
+ typedef typename remove_all<_Lhs>::type Lhs;
+ typedef typename remove_all<_Rhs>::type Rhs;
+ typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
+ typedef Sparse StorageKind;
+ typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;
+
+ enum {
+ LhsFlags = Lhs::Flags,
+ RhsFlags = Rhs::Flags,
+
+ RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime),
+ ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime),
+ MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime),
+ MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime),
+
+ EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
+ RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
+
+ Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
+ | EvalBeforeNestingBit | EvalBeforeAssigningBit,
+ CoeffReadCost = Dynamic
+ };
+};
+
+} // end namespace internal
+
/*!
+ * \ingroup KroneckerProduct_Module
+ *
* Computes Kronecker tensor product of two dense matrices
*
+ * \warning If you want to replace a matrix by its Kronecker product
+ * with some matrix, do \b NOT do this:
+ * \code
+ * A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
+ * \endcode
+ * instead, use eval() to work around this:
+ * \code
+ * A = kroneckerProduct(A,B).eval();
+ * \endcode
+ *
* \param a Dense matrix a
* \param b Dense matrix b
* \return Kronecker tensor product of a and b
@@ -212,8 +223,10 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<
}
/*!
+ * \ingroup KroneckerProduct_Module
+ *
* Computes Kronecker tensor product of two matrices, at least one of
- * which is sparse.
+ * which is sparse
*
* \param a Dense/sparse matrix a
* \param b Dense/sparse matrix b
@@ -228,14 +241,6 @@ KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenB
template<typename Derived>
template<typename Lhs, typename Rhs>
-Derived& MatrixBase<Derived>::lazyAssign(const KroneckerProduct<Lhs,Rhs>& other)
-{
- other.evalTo(derived());
- return derived();
-}
-
-template<typename Derived>
-template<typename Lhs, typename Rhs>
Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product)
{
product.evalTo(derived());