diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-06-25 11:36:38 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-06-25 11:36:38 +0200 |
commit | e313826890f581f1b9665422bab7b83b9daf5bfd (patch) | |
tree | a00dbdf84203cb652008e58646b8661049d21f3a /Eigen/src | |
parent | 1927b4dff513f66866de205a46c66a1f2c877d01 (diff) |
add mixed sparse-dense outer product
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Sparse/SparseDenseProduct.h | 116 | ||||
-rw-r--r-- | Eigen/src/Sparse/SparseMatrixBase.h | 10 | ||||
-rw-r--r-- | Eigen/src/Sparse/SparseUtil.h | 3 |
3 files changed, 122 insertions, 7 deletions
diff --git a/Eigen/src/Sparse/SparseDenseProduct.h b/Eigen/src/Sparse/SparseDenseProduct.h index 8eaa60a26..0489c68db 100644 --- a/Eigen/src/Sparse/SparseDenseProduct.h +++ b/Eigen/src/Sparse/SparseDenseProduct.h @@ -25,6 +25,118 @@ #ifndef EIGEN_SPARSEDENSEPRODUCT_H #define EIGEN_SPARSEDENSEPRODUCT_H +template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType +{ + typedef SparseTimeDenseProduct<Lhs,Rhs> Type; +}; + +template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1> +{ + typedef SparseDenseOuterProduct<Lhs,Rhs,false> Type; +}; + +template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType +{ + typedef DenseTimeSparseProduct<Lhs,Rhs> Type; +}; + +template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1> +{ + typedef SparseDenseOuterProduct<Rhs,Lhs,true> Type; +}; + +template<typename Lhs, typename Rhs, bool Tr> +struct ei_traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> > +{ + typedef Sparse StorageKind; + typedef typename ei_scalar_product_traits<typename ei_traits<Lhs>::Scalar, + typename ei_traits<Rhs>::Scalar>::ReturnType Scalar; + typedef typename Lhs::Index Index; + typedef typename Lhs::Nested LhsNested; + typedef typename Rhs::Nested RhsNested; + typedef typename ei_cleantype<LhsNested>::type _LhsNested; + typedef typename ei_cleantype<RhsNested>::type _RhsNested; + + enum { + LhsCoeffReadCost = ei_traits<_LhsNested>::CoeffReadCost, + RhsCoeffReadCost = ei_traits<_RhsNested>::CoeffReadCost, + + RowsAtCompileTime = Tr ? int(ei_traits<Rhs>::RowsAtCompileTime) : int(ei_traits<Lhs>::RowsAtCompileTime), + ColsAtCompileTime = Tr ? int(ei_traits<Lhs>::ColsAtCompileTime) : int(ei_traits<Rhs>::ColsAtCompileTime), + MaxRowsAtCompileTime = Tr ? int(ei_traits<Rhs>::MaxRowsAtCompileTime) : int(ei_traits<Lhs>::MaxRowsAtCompileTime), + MaxColsAtCompileTime = Tr ? int(ei_traits<Lhs>::MaxColsAtCompileTime) : int(ei_traits<Rhs>::MaxColsAtCompileTime), + + Flags = Tr ? RowMajorBit : 0, + + CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost + }; +}; + +template<typename Lhs, typename Rhs, bool Tr> +class SparseDenseOuterProduct + : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> > +{ + public: + + typedef SparseMatrixBase<SparseDenseOuterProduct> Base; + EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct) + typedef ei_traits<SparseDenseOuterProduct> Traits; + + private: + + typedef typename Traits::LhsNested LhsNested; + typedef typename Traits::RhsNested RhsNested; + typedef typename Traits::_LhsNested _LhsNested; + typedef typename Traits::_RhsNested _RhsNested; + + public: + + class InnerIterator; + + EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs) + : m_lhs(lhs), m_rhs(rhs) + { + EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); + } + + EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs) + : m_lhs(lhs), m_rhs(rhs) + { + EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); + } + + EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); } + EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); } + + EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } + EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } + + protected: + LhsNested m_lhs; + RhsNested m_rhs; +}; + +template<typename Lhs, typename Rhs, bool Transpose> +class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator +{ + typedef typename _LhsNested::InnerIterator Base; + public: + EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) + : Base(prod.lhs(), 0), m_outer(outer), m_factor(prod.rhs().coeff(outer)) + { + } + + inline Index outer() const { return m_outer; } + inline Index row() const { return Transpose ? Base::row() : m_outer; } + inline Index col() const { return Transpose ? m_outer : Base::row(); } + + inline Scalar value() const { return Base::value() * m_factor; } + + protected: + int m_outer; + Scalar m_factor; +}; + template<typename Lhs, typename Rhs> struct ei_traits<SparseTimeDenseProduct<Lhs,Rhs> > : ei_traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> > @@ -102,10 +214,10 @@ class DenseTimeSparseProduct // sparse * dense template<typename Derived> template<typename OtherDerived> -inline const SparseTimeDenseProduct<Derived,OtherDerived> +inline const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const { - return SparseTimeDenseProduct<Derived,OtherDerived>(derived(), other.derived()); + return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); } #endif // EIGEN_SPARSEDENSEPRODUCT_H diff --git a/Eigen/src/Sparse/SparseMatrixBase.h b/Eigen/src/Sparse/SparseMatrixBase.h index 12a1cb538..5ca3b604b 100644 --- a/Eigen/src/Sparse/SparseMatrixBase.h +++ b/Eigen/src/Sparse/SparseMatrixBase.h @@ -362,15 +362,15 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived> operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs) { return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); } - // dense * sparse (return a dense object) + /** dense * sparse (return a dense object unless it is an outer product) */ template<typename OtherDerived> friend - const DenseTimeSparseProduct<OtherDerived,Derived> + const typename DenseSparseProductReturnType<OtherDerived,Derived>::Type operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs) - { return DenseTimeSparseProduct<OtherDerived,Derived>(lhs.derived(),rhs); } + { return typename DenseSparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); } - // sparse * dense (returns a dense object) + /** sparse * dense (returns a dense object unless it is an outer product) */ template<typename OtherDerived> - const SparseTimeDenseProduct<Derived,OtherDerived> + const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type operator*(const MatrixBase<OtherDerived> &other) const; template<typename OtherDerived> diff --git a/Eigen/src/Sparse/SparseUtil.h b/Eigen/src/Sparse/SparseUtil.h index 81941994f..ddfa115dc 100644 --- a/Eigen/src/Sparse/SparseUtil.h +++ b/Eigen/src/Sparse/SparseUtil.h @@ -95,8 +95,11 @@ template<typename MatrixType> class SparseView; template<typename Lhs, typename Rhs> class SparseSparseProduct; template<typename Lhs, typename Rhs> class SparseTimeDenseProduct; template<typename Lhs, typename Rhs> class DenseTimeSparseProduct; +template<typename Lhs, typename Rhs, bool Transpose> class SparseDenseOuterProduct; template<typename Lhs, typename Rhs> struct SparseSparseProductReturnType; +template<typename Lhs, typename Rhs, int InnerSize = ei_traits<Lhs>::ColsAtCompileTime> struct DenseSparseProductReturnType; +template<typename Lhs, typename Rhs, int InnerSize = ei_traits<Lhs>::ColsAtCompileTime> struct SparseDenseProductReturnType; template<typename T> struct ei_eval<T,Sparse> { |