diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 10:51:57 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 10:51:57 +0200 |
commit | fad36cc8148fc4f4581ebb5b7c4a0ae4438df00a (patch) | |
tree | 1e93e8b1a18d54ca3475dc3a8f3e61ef73f10b04 /Eigen/src/SparseCore/SparsePermutation.h | |
parent | 06036d8bb1ff918ac63995f349b80204b895143c (diff) |
Clean implementation of permutation * matrix products.
Diffstat (limited to 'Eigen/src/SparseCore/SparsePermutation.h')
-rw-r--r-- | Eigen/src/SparseCore/SparsePermutation.h | 133 |
1 files changed, 32 insertions, 101 deletions
diff --git a/Eigen/src/SparseCore/SparsePermutation.h b/Eigen/src/SparseCore/SparsePermutation.h index 4be93c18c..b15128979 100644 --- a/Eigen/src/SparseCore/SparsePermutation.h +++ b/Eigen/src/SparseCore/SparsePermutation.h @@ -16,25 +16,8 @@ namespace Eigen { namespace internal { -template<typename PermutationType, typename MatrixType, int Side, bool Transposed> -struct traits<permut_sparsematrix_product_retval<PermutationType, MatrixType, Side, Transposed> > -{ - typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned; - typedef typename MatrixTypeNestedCleaned::Scalar Scalar; - typedef typename MatrixTypeNestedCleaned::StorageIndex StorageIndex; - enum { - SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor, - MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight - }; - - typedef typename internal::conditional<MoveOuter, - SparseMatrix<Scalar,SrcStorageOrder,StorageIndex>, - SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> >::type ReturnType; -}; - -template<typename PermutationType, typename MatrixType, int Side, bool Transposed> -struct permut_sparsematrix_product_retval - : public ReturnByValue<permut_sparsematrix_product_retval<PermutationType, MatrixType, Side, Transposed> > +template<typename MatrixType, int Side, bool Transposed> +struct permutation_matrix_product<MatrixType, Side, Transposed, SparseShape> { typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned; typedef typename MatrixTypeNestedCleaned::Scalar Scalar; @@ -44,61 +27,55 @@ struct permut_sparsematrix_product_retval SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor, MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight }; + + typedef typename internal::conditional<MoveOuter, + SparseMatrix<Scalar,SrcStorageOrder,StorageIndex>, + SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> >::type ReturnType; - permut_sparsematrix_product_retval(const PermutationType& perm, const MatrixType& matrix) - : m_permutation(perm), m_matrix(matrix) - {} - - inline int rows() const { return m_matrix.rows(); } - inline int cols() const { return m_matrix.cols(); } - - template<typename Dest> inline void evalTo(Dest& dst) const + template<typename Dest,typename PermutationType> + static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat) { if(MoveOuter) { - SparseMatrix<Scalar,SrcStorageOrder,StorageIndex> tmp(m_matrix.rows(), m_matrix.cols()); - Matrix<StorageIndex,Dynamic,1> sizes(m_matrix.outerSize()); - for(Index j=0; j<m_matrix.outerSize(); ++j) + SparseMatrix<Scalar,SrcStorageOrder,StorageIndex> tmp(mat.rows(), mat.cols()); + Matrix<StorageIndex,Dynamic,1> sizes(mat.outerSize()); + for(Index j=0; j<mat.outerSize(); ++j) { - Index jp = m_permutation.indices().coeff(j); - sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = StorageIndex(m_matrix.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).nonZeros()); + Index jp = perm.indices().coeff(j); + sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = StorageIndex(mat.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).nonZeros()); } tmp.reserve(sizes); - for(Index j=0; j<m_matrix.outerSize(); ++j) + for(Index j=0; j<mat.outerSize(); ++j) { - Index jp = m_permutation.indices().coeff(j); + Index jp = perm.indices().coeff(j); Index jsrc = ((Side==OnTheRight) ^ Transposed) ? jp : j; Index jdst = ((Side==OnTheLeft) ^ Transposed) ? jp : j; - for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,jsrc); it; ++it) + for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,jsrc); it; ++it) tmp.insertByOuterInner(jdst,it.index()) = it.value(); } dst = tmp; } else { - SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> tmp(m_matrix.rows(), m_matrix.cols()); + SparseMatrix<Scalar,int(SrcStorageOrder)==RowMajor?ColMajor:RowMajor,StorageIndex> tmp(mat.rows(), mat.cols()); Matrix<StorageIndex,Dynamic,1> sizes(tmp.outerSize()); sizes.setZero(); - PermutationMatrix<Dynamic,Dynamic,StorageIndex> perm; + PermutationMatrix<Dynamic,Dynamic,StorageIndex> perm_cpy; if((Side==OnTheLeft) ^ Transposed) - perm = m_permutation; + perm_cpy = perm; else - perm = m_permutation.transpose(); + perm_cpy = perm.transpose(); - for(Index j=0; j<m_matrix.outerSize(); ++j) - for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,j); it; ++it) - sizes[perm.indices().coeff(it.index())]++; + for(Index j=0; j<mat.outerSize(); ++j) + for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,j); it; ++it) + sizes[perm_cpy.indices().coeff(it.index())]++; tmp.reserve(sizes); - for(Index j=0; j<m_matrix.outerSize(); ++j) - for(typename MatrixTypeNestedCleaned::InnerIterator it(m_matrix,j); it; ++it) - tmp.insertByOuterInner(perm.indices().coeff(it.index()),j) = it.value(); + for(Index j=0; j<mat.outerSize(); ++j) + for(typename MatrixTypeNestedCleaned::InnerIterator it(mat,j); it; ++it) + tmp.insertByOuterInner(perm_cpy.indices().coeff(it.index()),j) = it.value(); dst = tmp; } } - - protected: - const PermutationType& m_permutation; - typename MatrixType::Nested m_matrix; }; } @@ -107,63 +84,17 @@ namespace internal { template <int ProductTag> struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> { typedef Sparse ret; }; template <int ProductTag> struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> { typedef Sparse ret; }; - -// TODO, the following need cleaning, this is just a copy-past of the dense case - -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag> -{ - template<typename Dest> - static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, false> pmpr(lhs, rhs); - pmpr.evalTo(dst); - } -}; - -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag> -{ - template<typename Dest> - static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, false> pmpr(rhs, lhs); - pmpr.evalTo(dst); - } -}; - -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, SparseShape, ProductTag> -{ - template<typename Dest> - static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, true> pmpr(lhs.nestedPermutation(), rhs); - pmpr.evalTo(dst); - } -}; - -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Transpose<Rhs>, SparseShape, PermutationShape, ProductTag> -{ - template<typename Dest> - static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) - { - permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, true> pmpr(rhs.nestedPermutation(), lhs); - pmpr.evalTo(dst); - } -}; // TODO, the following two overloads are only needed to define the right temporary type through -// typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType -// while it should be correctly handled by traits<Product<> >::PlainObject +// typename traits<permutation_sparse_matrix_product<Rhs,Lhs,OnTheRight,false> >::ReturnType +// whereas it should be correctly handled by traits<Product<> >::PlainObject template<typename Lhs, typename Rhs, int ProductTag> struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, PermutationShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> - : public evaluator<typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType>::type + : public evaluator<typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType>::type { typedef Product<Lhs, Rhs, DefaultProduct> XprType; - typedef typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType PlainObject; + typedef typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType PlainObject; typedef typename evaluator<PlainObject>::type Base; explicit product_evaluator(const XprType& xpr) @@ -179,10 +110,10 @@ protected: template<typename Lhs, typename Rhs, int ProductTag> struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, PermutationShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> - : public evaluator<typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType>::type + : public evaluator<typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType>::type { typedef Product<Lhs, Rhs, DefaultProduct> XprType; - typedef typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType PlainObject; + typedef typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType PlainObject; typedef typename evaluator<PlainObject>::type Base; explicit product_evaluator(const XprType& xpr) |