diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-10-13 11:30:41 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-10-13 11:30:41 +0200 |
commit | b4c79ee1d3d7b44e58f2bea48cd597aa0fa7e007 (patch) | |
tree | e7f58f4fb55fffba3234b1cdf8856aeaabd5bac5 /Eigen/src/SparseCore/SparseMatrix.h | |
parent | b9d81c915009e08a2397a2fc2d36a15d16b3b32f (diff) |
Update custom setFromTripplets API to allow passing a functor object, and add a collapseDuplicates method to cleanup the API. Also add respective unit test
Diffstat (limited to 'Eigen/src/SparseCore/SparseMatrix.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseMatrix.h | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/Eigen/src/SparseCore/SparseMatrix.h b/Eigen/src/SparseCore/SparseMatrix.h index 22a6bd803..5e2b14554 100644 --- a/Eigen/src/SparseCore/SparseMatrix.h +++ b/Eigen/src/SparseCore/SparseMatrix.h @@ -437,11 +437,13 @@ class SparseMatrix template<typename InputIterators> void setFromTriplets(const InputIterators& begin, const InputIterators& end); - template<typename DupFunctor, typename InputIterators> - void setFromTriplets(const InputIterators& begin, const InputIterators& end); + template<typename InputIterators,typename DupFunctor> + void setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func); + + void sumupDuplicates() { collapseDuplicates(internal::scalar_sum_op<Scalar>()); } template<typename DupFunctor> - void sumupDuplicates(); + void collapseDuplicates(DupFunctor dup_func = DupFunctor()); //--- @@ -894,9 +896,8 @@ private: namespace internal { template<typename InputIterator, typename SparseMatrixType, typename DupFunctor> -void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, int Options = 0) +void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, DupFunctor dup_func) { - EIGEN_UNUSED_VARIABLE(Options); enum { IsRowMajor = SparseMatrixType::IsRowMajor }; typedef typename SparseMatrixType::Scalar Scalar; typedef typename SparseMatrixType::StorageIndex StorageIndex; @@ -919,7 +920,7 @@ void set_from_triplets(const InputIterator& begin, const InputIterator& end, Spa trMat.insertBackUncompressed(it->row(),it->col()) = it->value(); // pass 3: - trMat.template sumupDuplicates<DupFunctor>(); + trMat.collapseDuplicates(dup_func); } // pass 4: transposed copy -> implicit sorting @@ -970,25 +971,29 @@ template<typename Scalar, int _Options, typename _Index> template<typename InputIterators> void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end) { - internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, internal::scalar_sum_op<Scalar> >(begin, end, *this); + internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index> >(begin, end, *this, internal::scalar_sum_op<Scalar>()); } -/** The same as setFromTriplets but when duplicates are met the functor \a DupFunctor is applied: +/** The same as setFromTriplets but when duplicates are met the functor \a dup_func is applied: * \code - * value = DupFunctor()(OldValue, NewValue) + * value = dup_func(OldValue, NewValue) * \endcode - */ + * Here is a C++11 example keeping the latest entry only: + * \code + * mat.setFromTriplets(triplets.begin(), triplets.end(), [] (const Scalar&,const Scalar &b) { return b; }); + * \endcode + */ template<typename Scalar, int _Options, typename _Index> -template<typename DupFunctor, typename InputIterators> -void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end) +template<typename InputIterators,typename DupFunctor> +void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func) { - internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this); + internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this, dup_func); } /** \internal */ template<typename Scalar, int _Options, typename _Index> template<typename DupFunctor> -void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates() +void SparseMatrix<Scalar,_Options,_Index>::collapseDuplicates(DupFunctor dup_func) { eigen_assert(!isCompressed()); // TODO, in practice we should be able to use m_innerNonZeros for that task @@ -1006,7 +1011,7 @@ void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates() if(wi(i)>=start) { // we already meet this entry => accumulate it - m_data.value(wi(i)) = DupFunctor()(m_data.value(wi(i)), m_data.value(k)); + m_data.value(wi(i)) = dup_func(m_data.value(wi(i)), m_data.value(k)); } else { |