From e313826890f581f1b9665422bab7b83b9daf5bfd Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 25 Jun 2010 11:36:38 +0200 Subject: add mixed sparse-dense outer product --- Eigen/src/Sparse/SparseDenseProduct.h | 116 +++++++++++++++++++++++++++++++++- Eigen/src/Sparse/SparseMatrixBase.h | 10 +-- Eigen/src/Sparse/SparseUtil.h | 3 + 3 files changed, 122 insertions(+), 7 deletions(-) (limited to 'Eigen') diff --git a/Eigen/src/Sparse/SparseDenseProduct.h b/Eigen/src/Sparse/SparseDenseProduct.h index 8eaa60a26..0489c68db 100644 --- a/Eigen/src/Sparse/SparseDenseProduct.h +++ b/Eigen/src/Sparse/SparseDenseProduct.h @@ -25,6 +25,118 @@ #ifndef EIGEN_SPARSEDENSEPRODUCT_H #define EIGEN_SPARSEDENSEPRODUCT_H +template struct SparseDenseProductReturnType +{ + typedef SparseTimeDenseProduct Type; +}; + +template struct SparseDenseProductReturnType +{ + typedef SparseDenseOuterProduct Type; +}; + +template struct DenseSparseProductReturnType +{ + typedef DenseTimeSparseProduct Type; +}; + +template struct DenseSparseProductReturnType +{ + typedef SparseDenseOuterProduct Type; +}; + +template +struct ei_traits > +{ + typedef Sparse StorageKind; + typedef typename ei_scalar_product_traits::Scalar, + typename ei_traits::Scalar>::ReturnType Scalar; + typedef typename Lhs::Index Index; + typedef typename Lhs::Nested LhsNested; + typedef typename Rhs::Nested RhsNested; + typedef typename ei_cleantype::type _LhsNested; + typedef typename ei_cleantype::type _RhsNested; + + enum { + LhsCoeffReadCost = ei_traits<_LhsNested>::CoeffReadCost, + RhsCoeffReadCost = ei_traits<_RhsNested>::CoeffReadCost, + + RowsAtCompileTime = Tr ? int(ei_traits::RowsAtCompileTime) : int(ei_traits::RowsAtCompileTime), + ColsAtCompileTime = Tr ? int(ei_traits::ColsAtCompileTime) : int(ei_traits::ColsAtCompileTime), + MaxRowsAtCompileTime = Tr ? int(ei_traits::MaxRowsAtCompileTime) : int(ei_traits::MaxRowsAtCompileTime), + MaxColsAtCompileTime = Tr ? int(ei_traits::MaxColsAtCompileTime) : int(ei_traits::MaxColsAtCompileTime), + + Flags = Tr ? RowMajorBit : 0, + + CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits::MulCost + }; +}; + +template +class SparseDenseOuterProduct + : public SparseMatrixBase > +{ + public: + + typedef SparseMatrixBase Base; + EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct) + typedef ei_traits Traits; + + private: + + typedef typename Traits::LhsNested LhsNested; + typedef typename Traits::RhsNested RhsNested; + typedef typename Traits::_LhsNested _LhsNested; + typedef typename Traits::_RhsNested _RhsNested; + + public: + + class InnerIterator; + + EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs) + : m_lhs(lhs), m_rhs(rhs) + { + EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); + } + + EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs) + : m_lhs(lhs), m_rhs(rhs) + { + EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); + } + + EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); } + EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); } + + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } + + protected: + LhsNested m_lhs; + RhsNested m_rhs; +}; + +template +class SparseDenseOuterProduct::InnerIterator : public _LhsNested::InnerIterator +{ + typedef typename _LhsNested::InnerIterator Base; + public: + EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) + : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer)) + { + } + + inline Index outer() const { return m_outer; } + inline Index row() const { return Transpose ? Base::row() : m_outer; } + inline Index col() const { return Transpose ? m_outer : Base::row(); } + + inline Scalar value() const { return Base::value() * m_factor; } + + protected: + int m_outer; + Scalar m_factor; +}; + template struct ei_traits > : ei_traits, Lhs, Rhs> > @@ -102,10 +214,10 @@ class DenseTimeSparseProduct // sparse * dense template template -inline const SparseTimeDenseProduct +inline const typename SparseDenseProductReturnType::Type SparseMatrixBase::operator*(const MatrixBase &other) const { - return SparseTimeDenseProduct(derived(), other.derived()); + return typename SparseDenseProductReturnType::Type(derived(), other.derived()); } #endif // EIGEN_SPARSEDENSEPRODUCT_H diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index 12a1cb538..5ca3b604b 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h @@ -362,15 +362,15 @@ template class SparseMatrixBase : public EigenBase operator*(const DiagonalBase &lhs, const SparseMatrixBase& rhs) { return SparseDiagonalProduct(lhs.derived(), rhs.derived()); } - // dense * sparse (return a dense object) + /** dense * sparse (return a dense object unless it is an outer product) */ template friend - const DenseTimeSparseProduct + const typename DenseSparseProductReturnType::Type operator*(const MatrixBase& lhs, const Derived& rhs) - { return DenseTimeSparseProduct(lhs.derived(),rhs); } + { return typename DenseSparseProductReturnType::Type(lhs.derived(),rhs); } - // sparse * dense (returns a dense object) + /** sparse * dense (returns a dense object unless it is an outer product) */ template - const SparseTimeDenseProduct + const typename SparseDenseProductReturnType::Type operator*(const MatrixBase &other) const; template diff --git a/Eigen/src/Sparse/SparseUtil.h b/Eigen/src/Sparse/SparseUtil.h index 81941994f..ddfa115dc 100644 --- a/Eigen/src/Sparse/SparseUtil.h +++ b/Eigen/src/Sparse/SparseUtil.h @@ -95,8 +95,11 @@ template class SparseView; template class SparseSparseProduct; template class SparseTimeDenseProduct; template class DenseTimeSparseProduct; +template class SparseDenseOuterProduct; template struct SparseSparseProductReturnType; +template::ColsAtCompileTime> struct DenseSparseProductReturnType; +template::ColsAtCompileTime> struct SparseDenseProductReturnType; template struct ei_eval { -- cgit v1.2.3