diff options
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 2 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseAssign.h | 11 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseDenseProduct.h | 155 | ||||
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrix.h | 4 |
4 files changed, 163 insertions, 9 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index f9f229b09..b19a29e53 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -109,7 +109,7 @@ struct evaluator_base typedef evaluator<ExpressionType> type; typedef evaluator<ExpressionType> nestedType; - typedef typename ExpressionType::Index Index; + typedef typename traits<ExpressionType>::Index Index; // TODO that's not very nice to have to propagate all these traits. They are currently only needed to handle outer,inner indices. typedef traits<ExpressionType> ExpressionTraits; }; diff --git a/Eigen/src/SparseCore/SparseAssign.h b/Eigen/src/SparseCore/SparseAssign.h index 066bc2de1..528d63e35 100644 --- a/Eigen/src/SparseCore/SparseAssign.h +++ b/Eigen/src/SparseCore/SparseAssign.h @@ -186,14 +186,14 @@ void assign_sparse_to_sparse(DstXprType &dst, const SrcXprType &src) SrcEvaluatorType srcEvaluator(src); const bool transpose = (DstEvaluatorType::Flags & RowMajorBit) != (SrcEvaluatorType::Flags & RowMajorBit); - const Index outerSize = (int(SrcEvaluatorType::Flags) & RowMajorBit) ? src.rows() : src.cols(); + const Index outerEvaluationSize = (SrcEvaluatorType::Flags&RowMajorBit) ? src.rows() : src.cols(); if ((!transpose) && src.isRValue()) { // eval without temporary dst.resize(src.rows(), src.cols()); dst.setZero(); dst.reserve((std::max)(src.rows(),src.cols())*2); - for (Index j=0; j<outerSize; ++j) + for (Index j=0; j<outerEvaluationSize; ++j) { dst.startVec(j); for (typename SrcEvaluatorType::InnerIterator it(srcEvaluator, j); it; ++it) @@ -213,11 +213,11 @@ void assign_sparse_to_sparse(DstXprType &dst, const SrcXprType &src) enum { Flip = (DstEvaluatorType::Flags & RowMajorBit) != (SrcEvaluatorType::Flags & RowMajorBit) }; - const Index outerSize = src.outerSize(); + DstXprType temp(src.rows(), src.cols()); temp.reserve((std::max)(src.rows(),src.cols())*2); - for (Index j=0; j<outerSize; ++j) + for (Index j=0; j<outerEvaluationSize; ++j) { temp.startVec(j); for (typename SrcEvaluatorType::InnerIterator it(srcEvaluator, j); it; ++it) @@ -256,7 +256,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar> dst.setZero(); typename internal::evaluator<SrcXprType>::type srcEval(src); typename internal::evaluator<DstXprType>::type dstEval(dst); - for (Index j=0; j<src.outerSize(); ++j) + const Index outerEvaluationSize = (internal::evaluator<SrcXprType>::Flags&RowMajorBit) ? src.rows() : src.cols(); + for (Index j=0; j<outerEvaluationSize; ++j) for (typename internal::evaluator<SrcXprType>::InnerIterator i(srcEval,j); i; ++i) dstEval.coeffRef(i.row(),i.col()) = i.value(); } diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h index 116edd62e..883e24acb 100644 --- a/Eigen/src/SparseCore/SparseDenseProduct.h +++ b/Eigen/src/SparseCore/SparseDenseProduct.h @@ -13,7 +13,10 @@ namespace Eigen { namespace internal { - + +template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; +template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; + template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType, int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, @@ -445,6 +448,156 @@ protected: PlainObject m_result; }; + +// template<typename Lhs, typename Rhs, bool Transpose, typename LhsIterator> +// class sparse_dense_outer_product_iterator : public LhsIterator +// { +// typedef typename SparseDenseOuterProduct::Index Index; +// public: +// template<typename XprEval> +// EIGEN_STRONG_INLINE InnerIterator(const XprEval& prod, Index outer) +// : LhsIterator(prod.lhs(), 0), +// m_outer(outer), m_empty(false), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() )) +// {} +// +// inline Index outer() const { return m_outer; } +// inline Index row() const { return Transpose ? m_outer : Base::index(); } +// inline Index col() const { return Transpose ? Base::index() : m_outer; } +// +// inline Scalar value() const { return Base::value() * m_factor; } +// inline operator bool() const { return Base::operator bool() && !m_empty; } +// +// protected: +// Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) const +// { +// return rhs.coeff(outer); +// } +// +// Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse()) +// { +// typename Traits::_RhsNested::InnerIterator it(rhs, outer); +// if (it && it.index()==0 && it.value()!=Scalar(0)) +// return it.value(); +// m_empty = true; +// return Scalar(0); +// } +// +// Index m_outer; +// bool m_empty; +// Scalar m_factor; +// }; + +template<typename LhsT, typename RhsT, bool Transpose> +struct sparse_dense_outer_product_evaluator +{ +protected: + typedef typename conditional<Transpose,RhsT,LhsT>::type Lhs1; + typedef typename conditional<Transpose,LhsT,RhsT>::type Rhs; + typedef Product<LhsT,RhsT> ProdXprType; + + // if the actual left-hand side is a dense vector, + // then build a sparse-view so that we can seamlessly iterator over it. + typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, + Lhs1, SparseView<Lhs1> >::type Lhs; + typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, + Lhs1 const&, SparseView<Lhs1> >::type LhsArg; + + typedef typename evaluator<Lhs>::type LhsEval; + typedef typename evaluator<Rhs>::type RhsEval; + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; + typedef typename ProdXprType::Scalar Scalar; + typedef typename ProdXprType::Index Index; + +public: + enum { + Flags = Transpose ? RowMajorBit : 0, + CoeffReadCost = Dynamic + }; + + class InnerIterator : public LhsIterator + { + public: + InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) + : LhsIterator(xprEval.m_lhsXprImpl, 0), + m_outer(outer), + m_empty(false), + m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<Rhs>::StorageKind() )) + {} + + EIGEN_STRONG_INLINE Index outer() const { return m_outer; } + EIGEN_STRONG_INLINE Index row() const { return Transpose ? m_outer : LhsIterator::index(); } + EIGEN_STRONG_INLINE Index col() const { return Transpose ? LhsIterator::index() : m_outer; } + + EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } + EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } + + + protected: + Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const + { + return rhs.coeff(outer); + } + + Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) + { + typename RhsEval::InnerIterator it(rhs, outer); + if (it && it.index()==0 && it.value()!=Scalar(0)) + return it.value(); + m_empty = true; + return Scalar(0); + } + + Index m_outer; + bool m_empty; + Scalar m_factor; + }; + + sparse_dense_outer_product_evaluator(const Lhs &lhs, const Rhs &rhs) + : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) + {} + + // transpose case + sparse_dense_outer_product_evaluator(const Rhs &rhs, const Lhs1 &lhs) + : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) + {} + +protected: + const LhsArg m_lhs; + typename evaluator<Lhs>::nestedType m_lhsXprImpl; + typename evaluator<Rhs>::nestedType m_rhsXprImpl; +}; + +// sparse * dense outer product +template<typename Lhs, typename Rhs> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> +{ + typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; + + typedef Product<Lhs, Rhs> XprType; + typedef typename XprType::PlainObject PlainObject; + + product_evaluator(const XprType& xpr) + : Base(xpr.lhs(), xpr.rhs()) + {} + +}; + +template<typename Lhs, typename Rhs> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> +{ + typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; + + typedef Product<Lhs, Rhs> XprType; + typedef typename XprType::PlainObject PlainObject; + + product_evaluator(const XprType& xpr) + : Base(xpr.lhs(), xpr.rhs()) + {} + +}; + } // end namespace internal #endif // EIGEN_TEST_EVALUATORS diff --git a/Eigen/src/SparseCore/SparseMatrix.h b/Eigen/src/SparseCore/SparseMatrix.h index 9d0d6c3ab..36a024cc7 100644 --- a/Eigen/src/SparseCore/SparseMatrix.h +++ b/Eigen/src/SparseCore/SparseMatrix.h @@ -1132,7 +1132,7 @@ EIGEN_DONT_INLINE SparseMatrix<Scalar,_Options,_Index>& SparseMatrix<Scalar,_Opt { EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value), YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) - + const bool needToTranspose = (Flags & RowMajorBit) != (internal::evaluator<OtherDerived>::Flags & RowMajorBit); if (needToTranspose) { @@ -1140,7 +1140,7 @@ EIGEN_DONT_INLINE SparseMatrix<Scalar,_Options,_Index>& SparseMatrix<Scalar,_Opt // 1 - compute the number of coeffs per dest inner vector // 2 - do the actual copy/eval // Since each coeff of the rhs has to be evaluated twice, let's evaluate it if needed - typedef typename internal::nested_eval<OtherDerived,2>::type OtherCopy; + typedef typename internal::nested_eval<OtherDerived,2,typename internal::plain_matrix_type<OtherDerived>::type >::type OtherCopy; typedef typename internal::remove_all<OtherCopy>::type _OtherCopy; typedef internal::evaluator<_OtherCopy> OtherCopyEval; OtherCopy otherCopy(other.derived()); |