diff options
Diffstat (limited to 'Eigen/src/SparseCore/SparseDiagonalProduct.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseDiagonalProduct.h | 133 |
1 files changed, 130 insertions, 3 deletions
diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h index c056b4914..9f465a828 100644 --- a/Eigen/src/SparseCore/SparseDiagonalProduct.h +++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2009-2014 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -24,8 +24,10 @@ namespace Eigen { // for that particular case // The two other cases are symmetric. +#ifndef EIGEN_TEST_EVALUATORS + namespace internal { - + template<typename Lhs, typename Rhs> struct traits<SparseDiagonalProduct<Lhs, Rhs> > { @@ -102,9 +104,14 @@ class SparseDiagonalProduct LhsNested m_lhs; RhsNested m_rhs; }; +#endif namespace internal { +#ifndef EIGEN_TEST_EVALUATORS + + + template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> class sparse_diagonal_product_inner_iterator_selector <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor> @@ -181,10 +188,129 @@ class sparse_diagonal_product_inner_iterator_selector inline Index row() const { return m_outer; } }; +#else // EIGEN_TEST_EVALUATORS +enum { + SDP_AsScalarProduct, + SDP_AsCwiseProduct +}; + +template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag> +struct sparse_diagonal_product_evaluator; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef evaluator<XprType> type; + typedef evaluator<XprType> nestedType; + enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base; + product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef evaluator<XprType> type; + typedef evaluator<XprType> nestedType; + enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; + product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {} +}; + +template<typename SparseXprType, typename DiagonalCoeffType> +struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct> +{ +protected: + typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator; + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + +public: + class InnerIterator : public SparseXprInnerIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), + m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) + {} + + EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } + protected: + typename DiagonalCoeffType::Scalar m_coeff; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) + : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) + {} + +protected: + typename evaluator<SparseXprType>::nestedType m_sparseXprImpl; + typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl; +}; + + +template<typename SparseXprType, typename DiagCoeffType> +struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct> +{ + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + + typedef CwiseBinaryOp<scalar_product_op<Scalar>, + const typename SparseXprType::ConstInnerVectorReturnType, + const DiagCoeffType> CwiseProductType; + + typedef typename evaluator<CwiseProductType>::type CwiseProductEval; + typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator; + + class InnerIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)), + m_cwiseIter(m_cwiseEval, 0), + m_outer(outer) + {} + + inline Scalar value() const { return m_cwiseIter.value(); } + inline Index index() const { return m_cwiseIter.index(); } + inline Index outer() const { return m_outer; } + inline Index col() const { return SparseXprType::IsRowMajor ? m_cwiseIter.index() : m_outer; } + inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : m_cwiseIter.index(); } + + EIGEN_STRONG_INLINE InnerIterator& operator++() + { ++m_cwiseIter; return *this; } + inline operator bool() const { return m_cwiseIter; } + + protected: + CwiseProductEval m_cwiseEval; + CwiseProductIterator m_cwiseIter; + Index m_outer; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) + : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff) + {} + +protected: + typename nested_eval<SparseXprType,1>::type m_sparseXprNested; + typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime + : SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested; +}; + +#endif // EIGEN_TEST_EVALUATORS + + } // end namespace internal -// SparseMatrixBase functions +#ifndef EIGEN_TEST_EVALUATORS +// SparseMatrixBase functions template<typename Derived> template<typename OtherDerived> const SparseDiagonalProduct<Derived,OtherDerived> @@ -192,6 +318,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co { return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived()); } +#endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen |