diff options
Diffstat (limited to 'Eigen/src/SparseCore/SparseDenseProduct.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseDenseProduct.h | 346 |
1 files changed, 157 insertions, 189 deletions
diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h index d40e966c1..5aea11425 100644 --- a/Eigen/src/SparseCore/SparseDenseProduct.h +++ b/Eigen/src/SparseCore/SparseDenseProduct.h @@ -1,7 +1,7 @@ // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // -// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -12,152 +12,10 @@ namespace Eigen { -template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType -{ - typedef SparseTimeDenseProduct<Lhs,Rhs> Type; -}; - -template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1> -{ - typedef typename internal::conditional< - Lhs::IsRowMajor, - SparseDenseOuterProduct<Rhs,Lhs,true>, - SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type; -}; - -template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType -{ - typedef DenseTimeSparseProduct<Lhs,Rhs> Type; -}; - -template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1> -{ - typedef typename internal::conditional< - Rhs::IsRowMajor, - SparseDenseOuterProduct<Rhs,Lhs,true>, - SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type; -}; - namespace internal { -template<typename Lhs, typename Rhs, bool Tr> -struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> > -{ - typedef Sparse StorageKind; - typedef typename scalar_product_traits<typename traits<Lhs>::Scalar, - typename traits<Rhs>::Scalar>::ReturnType Scalar; - typedef typename Lhs::Index Index; - typedef typename Lhs::Nested LhsNested; - typedef typename Rhs::Nested RhsNested; - typedef typename remove_all<LhsNested>::type _LhsNested; - typedef typename remove_all<RhsNested>::type _RhsNested; - - enum { - LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost, - RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost, - - RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime), - ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime), - MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime), - MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime), - - Flags = Tr ? RowMajorBit : 0, - - CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost - }; -}; - -} // end namespace internal - -template<typename Lhs, typename Rhs, bool Tr> -class SparseDenseOuterProduct - : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> > -{ - public: - - typedef SparseMatrixBase<SparseDenseOuterProduct> Base; - EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct) - typedef internal::traits<SparseDenseOuterProduct> Traits; - - private: - - typedef typename Traits::LhsNested LhsNested; - typedef typename Traits::RhsNested RhsNested; - typedef typename Traits::_LhsNested _LhsNested; - typedef typename Traits::_RhsNested _RhsNested; - - public: - - class InnerIterator; - - EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs) - : m_lhs(lhs), m_rhs(rhs) - { - EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); - } - - EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs) - : m_lhs(lhs), m_rhs(rhs) - { - EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE); - } - - EIGEN_STRONG_INLINE Index rows() const { return Tr ? Index(m_rhs.rows()) : m_lhs.rows(); } - EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : Index(m_rhs.cols()); } - - EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } - EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } - - protected: - LhsNested m_lhs; - RhsNested m_rhs; -}; - -template<typename Lhs, typename Rhs, bool Transpose> -class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator -{ - typedef typename _LhsNested::InnerIterator Base; - typedef typename SparseDenseOuterProduct::Index Index; - public: - EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer) - : Base(prod.lhs(), 0), m_outer(outer), m_empty(false), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() )) - {} - - inline Index outer() const { return m_outer; } - inline Index row() const { return Transpose ? m_outer : Base::index(); } - inline Index col() const { return Transpose ? Base::index() : m_outer; } - - inline Scalar value() const { return Base::value() * m_factor; } - inline operator bool() const { return Base::operator bool() && !m_empty; } - - protected: - Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) const - { - return rhs.coeff(outer); - } - - Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse()) - { - typename Traits::_RhsNested::InnerIterator it(rhs, outer); - if (it && it.index()==0 && it.value()!=Scalar(0)) - return it.value(); - m_empty = true; - return Scalar(0); - } - - Index m_outer; - bool m_empty; - Scalar m_factor; -}; - -namespace internal { -template<typename Lhs, typename Rhs> -struct traits<SparseTimeDenseProduct<Lhs,Rhs> > - : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> > -{ - typedef Dense StorageKind; - typedef MatrixXpr XprKind; -}; +template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; +template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType, @@ -172,16 +30,17 @@ struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, t typedef typename internal::remove_all<DenseRhsType>::type Rhs; typedef typename internal::remove_all<DenseResType>::type Res; typedef typename Lhs::Index Index; - typedef typename Lhs::InnerIterator LhsInnerIterator; + typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { + typename evaluator<Lhs>::type lhsEval(lhs); for(Index c=0; c<rhs.cols(); ++c) { Index n = lhs.outerSize(); for(Index j=0; j<n; ++j) { typename Res::Scalar tmp(0); - for(LhsInnerIterator it(lhs,j); it ;++it) + for(LhsInnerIterator it(lhsEval,j); it ;++it) tmp += it.value() * rhs.coeff(it.index(),c); res.coeffRef(j,c) = alpha * tmp; } @@ -203,17 +62,18 @@ struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, A typedef typename internal::remove_all<SparseLhsType>::type Lhs; typedef typename internal::remove_all<DenseRhsType>::type Rhs; typedef typename internal::remove_all<DenseResType>::type Res; - typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) { + typename evaluator<Lhs>::type lhsEval(lhs); for(Index c=0; c<rhs.cols(); ++c) { for(Index j=0; j<lhs.outerSize(); ++j) { // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); - for(LhsInnerIterator it(lhs,j); it ;++it) + for(LhsInnerIterator it(lhsEval,j); it ;++it) res.coeffRef(it.index(),c) += it.value() * rhs_j; } } @@ -226,14 +86,15 @@ struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, t typedef typename internal::remove_all<SparseLhsType>::type Lhs; typedef typename internal::remove_all<DenseRhsType>::type Rhs; typedef typename internal::remove_all<DenseResType>::type Res; - typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { + typename evaluator<Lhs>::type lhsEval(lhs); for(Index j=0; j<lhs.outerSize(); ++j) { typename Res::RowXpr res_j(res.row(j)); - for(LhsInnerIterator it(lhs,j); it ;++it) + for(LhsInnerIterator it(lhsEval,j); it ;++it) res_j += (alpha*it.value()) * rhs.row(it.index()); } } @@ -245,14 +106,15 @@ struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, t typedef typename internal::remove_all<SparseLhsType>::type Lhs; typedef typename internal::remove_all<DenseRhsType>::type Rhs; typedef typename internal::remove_all<DenseResType>::type Res; - typedef typename Lhs::InnerIterator LhsInnerIterator; typedef typename Lhs::Index Index; + typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) { + typename evaluator<Lhs>::type lhsEval(lhs); for(Index j=0; j<lhs.outerSize(); ++j) { typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); - for(LhsInnerIterator it(lhs,j); it ;++it) + for(LhsInnerIterator it(lhsEval,j); it ;++it) res.row(it.index()) += (alpha*it.value()) * rhs_j; } } @@ -266,58 +128,164 @@ inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsTy } // end namespace internal -template<typename Lhs, typename Rhs> -class SparseTimeDenseProduct - : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> -{ - public: - EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct) - - SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) - {} - - template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const - { - internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha); - } +namespace internal { - private: - SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&); +template<typename Lhs, typename Rhs, int ProductType> +struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; + typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; + LhsNested lhsNested(lhs); + RhsNested rhsNested(rhs); + + dst.setZero(); + internal::sparse_time_dense_product(lhsNested, rhsNested, dst, typename Dest::Scalar(1)); + } }; +template<typename Lhs, typename Rhs, int ProductType> +struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> + : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> +{}; -// dense = dense * sparse -namespace internal { -template<typename Lhs, typename Rhs> -struct traits<DenseTimeSparseProduct<Lhs,Rhs> > - : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> > +template<typename Lhs, typename Rhs, int ProductType> +struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> { - typedef Dense StorageKind; + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; + typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; + LhsNested lhsNested(lhs); + RhsNested rhsNested(rhs); + + dst.setZero(); + // transpose everything + Transpose<Dest> dstT(dst); + internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, typename Dest::Scalar(1)); + } }; -} // end namespace internal -template<typename Lhs, typename Rhs> -class DenseTimeSparseProduct - : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> +template<typename Lhs, typename Rhs, int ProductType> +struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> + : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> +{}; + +template<typename LhsT, typename RhsT, bool NeedToTranspose> +struct sparse_dense_outer_product_evaluator { +protected: + typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; + typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; + typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; + + // if the actual left-hand side is a dense vector, + // then build a sparse-view so that we can seamlessly iterate over it. + typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, + Lhs1, SparseView<Lhs1> >::type ActualLhs; + typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, + Lhs1 const&, SparseView<Lhs1> >::type LhsArg; + + typedef typename evaluator<ActualLhs>::type LhsEval; + typedef typename evaluator<ActualRhs>::type RhsEval; + typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; + typedef typename ProdXprType::Scalar Scalar; + typedef typename ProdXprType::Index Index; + +public: + enum { + Flags = NeedToTranspose ? RowMajorBit : 0, + CoeffReadCost = Dynamic + }; + + class InnerIterator : public LhsIterator + { public: - EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct) - - DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) + InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) + : LhsIterator(xprEval.m_lhsXprImpl, 0), + m_outer(outer), + m_empty(false), + m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) {} + + EIGEN_STRONG_INLINE Index outer() const { return m_outer; } + EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } + EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } - template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const + EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } + EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } + + protected: + Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const { - Transpose<const _LhsNested> lhs_t(m_lhs); - Transpose<const _RhsNested> rhs_t(m_rhs); - Transpose<Dest> dest_t(dest); - internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha); + return rhs.coeff(outer); + } + + Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) + { + typename RhsEval::InnerIterator it(rhs, outer); + if (it && it.index()==0 && it.value()!=Scalar(0)) + return it.value(); + m_empty = true; + return Scalar(0); } + + Index m_outer; + bool m_empty; + Scalar m_factor; + }; + + sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) + : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) + {} + + // transpose case + sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) + : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) + {} + +protected: + const LhsArg m_lhs; + typename evaluator<ActualLhs>::nestedType m_lhsXprImpl; + typename evaluator<ActualRhs>::nestedType m_rhsXprImpl; +}; - private: - DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&); +// sparse * dense outer product +template<typename Lhs, typename Rhs> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> +{ + typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; + + typedef Product<Lhs, Rhs> XprType; + typedef typename XprType::PlainObject PlainObject; + + explicit product_evaluator(const XprType& xpr) + : Base(xpr.lhs(), xpr.rhs()) + {} + }; +template<typename Lhs, typename Rhs> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> +{ + typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; + + typedef Product<Lhs, Rhs> XprType; + typedef typename XprType::PlainObject PlainObject; + + explicit product_evaluator(const XprType& xpr) + : Base(xpr.lhs(), xpr.rhs()) + {} + +}; + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_SPARSEDENSEPRODUCT_H |