diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-06-27 15:54:44 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-06-27 15:54:44 +0200 |
commit | 73e686c6a44adfd54fe75a0d7581fd14bfa58f54 (patch) | |
tree | 9e4da476ae70595190bdacb8a0df4e079d6620af /Eigen/src/SparseCore | |
parent | ae039dde135a6af852d7028abd772316613a5249 (diff) |
Implement evaluators for sparse times diagonal products.
Diffstat (limited to 'Eigen/src/SparseCore')
-rw-r--r-- | Eigen/src/SparseCore/SparseDiagonalProduct.h | 126 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrixBase.h | 14 |
2 files changed, 138 insertions, 2 deletions
diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h index 1bb590e64..cf0f77342 100644 --- a/Eigen/src/SparseCore/SparseDiagonalProduct.h +++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h @@ -24,8 +24,10 @@ namespace Eigen { // for that particular case // The two other cases are symmetric. +#ifndef EIGEN_TEST_EVALUATORS + namespace internal { - + template<typename Lhs, typename Rhs> struct traits<SparseDiagonalProduct<Lhs, Rhs> > { @@ -100,9 +102,14 @@ class SparseDiagonalProduct LhsNested m_lhs; RhsNested m_rhs; }; +#endif namespace internal { +#ifndef EIGEN_TEST_EVALUATORS + + + template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> class sparse_diagonal_product_inner_iterator_selector <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor> @@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector inline Index row() const { return m_outer; } }; +#else // EIGEN_TEST_EVALUATORS +enum { + SDP_AsScalarProduct, + SDP_AsCwiseProduct +}; + +template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag> +struct sparse_diagonal_product_evaluator; + +template<typename Lhs, typename Rhs, int Options, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> +{ + typedef Product<Lhs, Rhs, Options> XprType; + typedef evaluator<XprType> type; + typedef evaluator<XprType> nestedType; + enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base; + product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} +}; + +template<typename Lhs, typename Rhs, int Options, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> +{ + typedef Product<Lhs, Rhs, Options> XprType; + typedef evaluator<XprType> type; + typedef evaluator<XprType> nestedType; + enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; + product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {} +}; + +template<typename SparseXprType, typename DiagonalCoeffType> +struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct> +{ +protected: + typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator; + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + +public: + class InnerIterator : public SparseXprInnerIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), + m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) + {} + + EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } + protected: + typename DiagonalCoeffType::Scalar m_coeff; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) + : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) + {} + +protected: + typename evaluator<SparseXprType>::nestedType m_sparseXprImpl; + typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl; +}; + + +template<typename SparseXprType, typename DiagCoeffType> +struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct> +{ + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + + typedef CwiseBinaryOp<scalar_product_op<Scalar>, + const typename SparseXprType::ConstInnerVectorReturnType, + const DiagCoeffType> CwiseProductType; + + typedef typename evaluator<CwiseProductType>::type CwiseProductEval; + typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator; + + class InnerIterator : public CwiseProductIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0), + m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)), + m_outer(outer) + { + ::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0); + } + + inline Index outer() const { return m_outer; } + inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; } + inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); } + + protected: + Index m_outer; + CwiseProductEval m_cwiseEval; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) + : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff) + {} + +protected: + typename nested_eval<SparseXprType,1>::type m_sparseXprNested; + typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime + : SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested; +}; + +#endif // EIGEN_TEST_EVALUATORS + + } // end namespace internal -// SparseMatrixBase functions +#ifndef EIGEN_TEST_EVALUATORS +// SparseMatrixBase functions template<typename Derived> template<typename OtherDerived> const SparseDiagonalProduct<Derived,OtherDerived> @@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co { return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived()); } +#endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index 4c46008ae..c71244d3e 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -269,6 +269,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type operator*(const SparseMatrixBase<OtherDerived> &other) const; +#ifndef EIGEN_TEST_EVALUATORS // sparse * diagonal template<typename OtherDerived> const SparseDiagonalProduct<Derived,OtherDerived> @@ -279,6 +280,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> const SparseDiagonalProduct<OtherDerived,Derived> operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs) { return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); } +#else // EIGEN_TEST_EVALUATORS + // sparse * diagonal + template<typename OtherDerived> + const Product<Derived,OtherDerived> + operator*(const DiagonalBase<OtherDerived> &other) const + { return Product<Derived,OtherDerived>(derived(), other.derived()); } + + // diagonal * sparse + template<typename OtherDerived> friend + const Product<OtherDerived,Derived> + operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs) + { return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); } +#endif // EIGEN_TEST_EVALUATORS /** dense * sparse (return a dense object unless it is an outer product) */ template<typename OtherDerived> friend |