diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 10:51:57 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 10:51:57 +0200 |
commit | fad36cc8148fc4f4581ebb5b7c4a0ae4438df00a (patch) | |
tree | 1e93e8b1a18d54ca3475dc3a8f3e61ef73f10b04 /Eigen/src/Core | |
parent | 06036d8bb1ff918ac63995f349b80204b895143c (diff) |
Clean implementation of permutation * matrix products.
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/PermutationMatrix.h | 78 | ||||
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 93 |
2 files changed, 76 insertions, 95 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index 99f5aecdd..9a0c03612 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -42,10 +42,6 @@ namespace Eigen { namespace internal { -template<typename PermutationType, typename MatrixType, int Side, bool Transposed=false> -struct permut_matrix_product_retval; -template<typename PermutationType, typename MatrixType, int Side, bool Transposed=false> -struct permut_sparsematrix_product_retval; enum PermPermProduct_t {PermPermProduct}; } // end namespace internal @@ -570,80 +566,6 @@ operator*(const PermutationBase<PermutationDerived> &permutation, namespace internal { -template<typename PermutationType, typename MatrixType, int Side, bool Transposed> -struct traits<permut_matrix_product_retval<PermutationType, MatrixType, Side, Transposed> > - : traits<typename MatrixType::PlainObject> -{ - typedef typename MatrixType::PlainObject ReturnType; -}; - -template<typename PermutationType, typename MatrixType, int Side, bool Transposed> -struct permut_matrix_product_retval - : public ReturnByValue<permut_matrix_product_retval<PermutationType, MatrixType, Side, Transposed> > -{ - typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned; - typedef typename MatrixType::StorageIndex StorageIndex; - - permut_matrix_product_retval(const PermutationType& perm, const MatrixType& matrix) - : m_permutation(perm), m_matrix(matrix) - {} - - inline Index rows() const { return m_matrix.rows(); } - inline Index cols() const { return m_matrix.cols(); } - - template<typename Dest> inline void evalTo(Dest& dst) const - { - const Index n = Side==OnTheLeft ? rows() : cols(); - // FIXME we need an is_same for expression that is not sensitive to constness. For instance - // is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true. - //if(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)) - if(is_same_dense(dst, m_matrix)) - { - // apply the permutation inplace - Matrix<bool,PermutationType::RowsAtCompileTime,1,0,PermutationType::MaxRowsAtCompileTime> mask(m_permutation.size()); - mask.fill(false); - Index r = 0; - while(r < m_permutation.size()) - { - // search for the next seed - while(r<m_permutation.size() && mask[r]) r++; - if(r>=m_permutation.size()) - break; - // we got one, let's follow it until we are back to the seed - Index k0 = r++; - Index kPrev = k0; - mask.coeffRef(k0) = true; - for(Index k=m_permutation.indices().coeff(k0); k!=k0; k=m_permutation.indices().coeff(k)) - { - Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k) - .swap(Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> - (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev)); - - mask.coeffRef(k) = true; - kPrev = k; - } - } - } - else - { - for(Index i = 0; i < n; ++i) - { - Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> - (dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i) - - = - - Block<const MatrixTypeNestedCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime> - (m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i); - } - } - } - - protected: - const PermutationType& m_permutation; - typename MatrixType::Nested m_matrix; -}; - /* Template partial specialization for transposed/inverse permutations */ template<typename Derived> diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 22b5e024b..9d1cb5d56 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -825,48 +825,107 @@ struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, /*************************************************************************** * Products with permutation matrices ***************************************************************************/ - -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Rhs, PermutationShape, DenseShape, ProductTag> + +/** \internal + * \class permutation_matrix_product + * Internal helper class implementing the product between a permutation matrix and a matrix. + * This class is specialized for DenseShape below and for SparseShape in SparseCore/SparsePermutation.h + */ +template<typename MatrixType, int Side, bool Transposed, typename MatrixShape> +struct permutation_matrix_product; + +template<typename MatrixType, int Side, bool Transposed> +struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape> +{ + typedef typename remove_all<MatrixType>::type MatrixTypeCleaned; + + template<typename Dest, typename PermutationType> + static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat) + { + const Index n = Side==OnTheLeft ? mat.rows() : mat.cols(); + // FIXME we need an is_same for expression that is not sensitive to constness. For instance + // is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true. + //if(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat)) + if(is_same_dense(dst, mat)) + { + // apply the permutation inplace + Matrix<bool,PermutationType::RowsAtCompileTime,1,0,PermutationType::MaxRowsAtCompileTime> mask(perm.size()); + mask.fill(false); + Index r = 0; + while(r < perm.size()) + { + // search for the next seed + while(r<perm.size() && mask[r]) r++; + if(r>=perm.size()) + break; + // we got one, let's follow it until we are back to the seed + Index k0 = r++; + Index kPrev = k0; + mask.coeffRef(k0) = true; + for(Index k=perm.indices().coeff(k0); k!=k0; k=perm.indices().coeff(k)) + { + Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k) + .swap(Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> + (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev)); + + mask.coeffRef(k) = true; + kPrev = k; + } + } + } + else + { + for(Index i = 0; i < n; ++i) + { + Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> + (dst, ((Side==OnTheLeft) ^ Transposed) ? perm.indices().coeff(i) : i) + + = + + Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime> + (mat, ((Side==OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i); + } + } + } +}; + +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Rhs, PermutationShape, MatrixShape, ProductTag> { template<typename Dest> static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) { - permut_matrix_product_retval<Lhs, Rhs, OnTheLeft, false> pmpr(lhs, rhs); - pmpr.evalTo(dst); + permutation_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs); } }; -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Rhs, DenseShape, PermutationShape, ProductTag> +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag> { template<typename Dest> static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) { - permut_matrix_product_retval<Rhs, Lhs, OnTheRight, false> pmpr(rhs, lhs); - pmpr.evalTo(dst); + permutation_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs); } }; -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, DenseShape, ProductTag> +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag> { template<typename Dest> static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) { - permut_matrix_product_retval<Lhs, Rhs, OnTheLeft, true> pmpr(lhs.nestedPermutation(), rhs); - pmpr.evalTo(dst); + permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs); } }; -template<typename Lhs, typename Rhs, int ProductTag> -struct generic_product_impl<Lhs, Transpose<Rhs>, DenseShape, PermutationShape, ProductTag> +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, ProductTag> { template<typename Dest> static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) { - permut_matrix_product_retval<Rhs, Lhs, OnTheRight, true> pmpr(rhs.nestedPermutation(), lhs); - pmpr.evalTo(dst); + permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs); } }; |