diff options
author | Chen-Pang He <jdh8@ms63.hinet.net> | 2012-10-15 19:45:50 +0800 |
---|---|---|
committer | Chen-Pang He <jdh8@ms63.hinet.net> | 2012-10-15 19:45:50 +0800 |
commit | 0508a0620b51a9faaffea0c520b5c1840dd32d29 (patch) | |
tree | f7d247cd8d6d68b915c8d2f63c2809a0d5686863 /unsupported/Eigen/src/KroneckerProduct | |
parent | 8284e7134b59bb8d4307d1207cc4bea5c68d5674 (diff) |
Let KroneckerProduct inherit ReturnByValue to eliminate temporary evaluation. It's uncommon to store the product back to one of the operands.
Diffstat (limited to 'unsupported/Eigen/src/KroneckerProduct')
-rw-r--r-- | unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h | 141 |
1 files changed, 73 insertions, 68 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h index c33d8f0ce..5149566a9 100644 --- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h +++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h @@ -18,59 +18,6 @@ namespace Eigen { -namespace internal { - -template<typename _Lhs, typename _Rhs> -struct traits<KroneckerProduct<_Lhs,_Rhs> > -{ - typedef MatrixXpr XprKind; - typedef typename remove_all<_Lhs>::type Lhs; - typedef typename remove_all<_Rhs>::type Rhs; - typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; - typedef Dense StorageKind; - typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; - - enum { - RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime), - ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime), - MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime), - MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime), - Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) - | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, - CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost - }; -}; - -template<typename _Lhs, typename _Rhs> -struct traits<KroneckerProductSparse<_Lhs,_Rhs> > -{ - typedef MatrixXpr XprKind; - typedef typename remove_all<_Lhs>::type Lhs; - typedef typename remove_all<_Rhs>::type Rhs; - typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; - typedef Sparse StorageKind; - typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; - - enum { - LhsFlags = Lhs::Flags, - RhsFlags = Rhs::Flags, - - RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime), - ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime), - MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime), - MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime), - - EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit), - RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), - - Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) - | EvalBeforeNestingBit | EvalBeforeAssigningBit, - CoeffReadCost = Dynamic - }; -}; - -} // end namespace internal - /*! * \brief Kronecker tensor product helper class for dense matrices * @@ -82,12 +29,14 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> > * \tparam Rhs Type of the rignt-hand side, a matrix expression. */ template<typename Lhs, typename Rhs> -class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> > +class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> > { - public: - typedef MatrixBase<KroneckerProduct> Base; - EIGEN_DENSE_PUBLIC_INTERFACE(KroneckerProduct) + private: + typedef ReturnByValue<KroneckerProduct> Base; + typedef typename Base::Scalar Scalar; + typedef typename Base::Index Index; + public: /*! \brief Constructor. */ KroneckerProduct(const Lhs& A, const Rhs& B) : m_A(A), m_B(B) @@ -99,13 +48,13 @@ class KroneckerProduct : public MatrixBase<KroneckerProduct<Lhs,Rhs> > inline Index rows() const { return m_A.rows() * m_B.rows(); } inline Index cols() const { return m_A.cols() * m_B.cols(); } - typename Base::CoeffReturnType coeff(Index row, Index col) const + Scalar coeff(Index row, Index col) const { return m_A.coeff(row / m_A.cols(), col / m_A.rows()) * m_B.coeff(row % m_A.cols(), col % m_A.rows()); } - typename Base::CoeffReturnType coeff(Index i) const + Scalar coeff(Index i) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct); return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size()); @@ -198,9 +147,71 @@ void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const } } +namespace internal { + +template<typename _Lhs, typename _Rhs> +struct traits<KroneckerProduct<_Lhs,_Rhs> > +{ + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; + + enum { + Rows = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime), + Cols = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime), + MaxRows = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime), + MaxCols = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime), + CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost + }; + + typedef Matrix<Scalar,Rows,Cols> ReturnType; +}; + +template<typename _Lhs, typename _Rhs> +struct traits<KroneckerProductSparse<_Lhs,_Rhs> > +{ + typedef MatrixXpr XprKind; + typedef typename remove_all<_Lhs>::type Lhs; + typedef typename remove_all<_Rhs>::type Rhs; + typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; + typedef Sparse StorageKind; + typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; + + enum { + LhsFlags = Lhs::Flags, + RhsFlags = Rhs::Flags, + + RowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime), + ColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime), + MaxRowsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime), + MaxColsAtCompileTime = EIGEN_SIZE_PRODUCT(traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime), + + EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit), + RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), + + Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) + | EvalBeforeNestingBit | EvalBeforeAssigningBit, + CoeffReadCost = Dynamic + }; +}; + +} // end namespace internal + /*! + * \ingroup KroneckerProduct_Module + * * Computes Kronecker tensor product of two dense matrices * + * \warning If you want to replace a matrix by its Kronecker product + * with some matrix, do \b NOT do this: + * \code + * A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect + * \endcode + * instead, use eval() to work around this: + * \code + * A = kroneckerProduct(A,B).eval(); + * \endcode + * * \param a Dense matrix a * \param b Dense matrix b * \return Kronecker tensor product of a and b @@ -212,8 +223,10 @@ KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase< } /*! + * \ingroup KroneckerProduct_Module + * * Computes Kronecker tensor product of two matrices, at least one of - * which is sparse. + * which is sparse * * \param a Dense/sparse matrix a * \param b Dense/sparse matrix b @@ -228,14 +241,6 @@ KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenB template<typename Derived> template<typename Lhs, typename Rhs> -Derived& MatrixBase<Derived>::lazyAssign(const KroneckerProduct<Lhs,Rhs>& other) -{ - other.evalTo(derived()); - return derived(); -} - -template<typename Derived> -template<typename Lhs, typename Rhs> Derived& SparseMatrixBase<Derived>::operator=(const KroneckerProductSparse<Lhs,Rhs>& product) { product.evalTo(derived()); |