diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-02-25 16:30:58 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-02-25 16:30:58 +0100 |
commit | 959a1b5d6335833e9ad49a088502705bb6967ff5 (patch) | |
tree | 0b41aa2f93a5b009b3699809ccfd4000582c92bf | |
parent | d9ca0c0d3643f4b777de686a2c0cddde075aa063 (diff) |
detect and implement inplace permutations
-rw-r--r-- | Eigen/src/Core/PermutationMatrix.h | 53 | ||||
-rw-r--r-- | Eigen/src/Core/Transpose.h | 19 | ||||
-rw-r--r-- | Eigen/src/Core/util/BlasUtil.h | 18 | ||||
-rw-r--r-- | Eigen/src/LU/FullPivLU.h | 8 | ||||
-rw-r--r-- | Eigen/src/LU/PartialPivLU.h | 5 | ||||
-rw-r--r-- | test/permutationmatrices.cpp | 19 |
6 files changed, 80 insertions, 42 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index c42812ec8..46884dc3f 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -326,21 +326,46 @@ struct ei_permut_matrix_product_retval template<typename Dest> inline void evalTo(Dest& dst) const { const int n = Side==OnTheLeft ? rows() : cols(); - for(int i = 0; i < n; ++i) + + if(ei_is_same_type<MatrixTypeNestedCleaned,Dest>::ret && ei_extract_data(dst) == ei_extract_data(m_matrix)) + { + // apply the permutation inplace + Matrix<bool,PermutationType::RowsAtCompileTime,1,0,PermutationType::MaxRowsAtCompileTime> mask(m_permutation.size()); + mask.fill(false); + int r = 0; + while(r < m_permutation.size()) + { + // search for the next seed + while(r<m_permutation.size() && mask[r]) r++; + if(r>=m_permutation.size()) + break; + // we got one, let's follow it until we are back to the seed + int k0 = r++; + int kPrev = k0; + mask.coeffRef(k0) = true; + for(int k=m_permutation.indices().coeff(k0); k!=k0; k=m_permutation.indices().coeff(k)) + { + Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k) + .swap(Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> + (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev)); + + mask.coeffRef(k) = true; + kPrev = k; + } + } + } + else { - Block< - Dest, - Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, - Side==OnTheRight ? 1 : Dest::ColsAtCompileTime - >(dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i) - - = - - Block< - MatrixTypeNestedCleaned, - Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime, - Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime - >(m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i); + for(int i = 0; i < n; ++i) + { + Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime> + (dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i) + + = + + Block<MatrixTypeNestedCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime> + (m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i); + } } } diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h index bd06d8464..6c0e50de2 100644 --- a/Eigen/src/Core/Transpose.h +++ b/Eigen/src/Core/Transpose.h @@ -295,25 +295,6 @@ struct ei_blas_traits<SelfCwiseBinaryOp<BinOp,NestedXpr> > static inline const XprType extract(const XprType& x) { return x; } }; - -template<typename T, int Access=ei_blas_traits<T>::ActualAccess> -struct ei_extract_data_selector { - static typename T::Scalar* run(const T& m) - { - return &ei_blas_traits<T>::extract(m).const_cast_derived().coeffRef(0,0); - } -}; - -template<typename T> -struct ei_extract_data_selector<T,NoDirectAccess> { - static typename T::Scalar* run(const T&) { return 0; } -}; - -template<typename T> typename T::Scalar* ei_extract_data(const T& m) -{ - return ei_extract_data_selector<T>::run(m); -} - template<typename Scalar, bool DestIsTranposed, typename OtherDerived> struct ei_check_transpose_aliasing_selector { diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 2ca463d5d..4d216d77a 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -236,4 +236,22 @@ struct ei_blas_traits<Transpose<NestedXpr> > static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } }; +template<typename T, int Access=ei_blas_traits<T>::ActualAccess> +struct ei_extract_data_selector { + static const typename T::Scalar* run(const T& m) + { + return &ei_blas_traits<T>::extract(m).const_cast_derived().coeffRef(0,0); // FIXME this should be .data() + } +}; + +template<typename T> +struct ei_extract_data_selector<T,NoDirectAccess> { + static typename T::Scalar* run(const T&) { return 0; } +}; + +template<typename T> const typename T::Scalar* ei_extract_data(const T& m) +{ + return ei_extract_data_selector<T>::run(m); +} + #endif // EIGEN_BLASUTIL_H diff --git a/Eigen/src/LU/FullPivLU.h b/Eigen/src/LU/FullPivLU.h index cd63b9ec7..dea6ec41c 100644 --- a/Eigen/src/LU/FullPivLU.h +++ b/Eigen/src/LU/FullPivLU.h @@ -119,7 +119,7 @@ template<typename _MatrixType> class FullPivLU * diagonal coefficient of U. */ RealScalar maxPivot() const { return m_maxpivot; } - + /** \returns the permutation matrix P * * \sa permutationQ() @@ -506,12 +506,10 @@ MatrixType FullPivLU<MatrixType>::reconstructedMatrix() const .template triangularView<Upper>().toDenseMatrix(); // P^{-1}(LU) - // FIXME implement inplace permutation - res = (m_p.inverse() * res).eval(); + res = m_p.inverse() * res; // (P^{-1}LU)Q^{-1} - // FIXME implement inplace permutation - res = (res * m_q.inverse()).eval(); + res = res * m_q.inverse(); return res; } diff --git a/Eigen/src/LU/PartialPivLU.h b/Eigen/src/LU/PartialPivLU.h index fcffc2458..ad0d6b810 100644 --- a/Eigen/src/LU/PartialPivLU.h +++ b/Eigen/src/LU/PartialPivLU.h @@ -412,10 +412,9 @@ MatrixType PartialPivLU<MatrixType>::reconstructedMatrix() const // LU MatrixType res = m_lu.template triangularView<UnitLower>().toDenseMatrix() * m_lu.template triangularView<Upper>(); - + // P^{-1}(LU) - // FIXME implement inplace permutation - res = (m_p.inverse() * res).eval(); + res = m_p.inverse() * res; return res; } diff --git a/test/permutationmatrices.cpp b/test/permutationmatrices.cpp index ae1bd8b85..89142d910 100644 --- a/test/permutationmatrices.cpp +++ b/test/permutationmatrices.cpp @@ -86,6 +86,23 @@ template<typename MatrixType> void permutationmatrices(const MatrixType& m) identityp.setIdentity(rows); VERIFY_IS_APPROX(m_original, identityp*m_original); + // check inplace permutations + m_permuted = m_original; + m_permuted = lp.inverse() * m_permuted; + VERIFY_IS_APPROX(m_permuted, lp.inverse()*m_original); + + m_permuted = m_original; + m_permuted = m_permuted * rp.inverse(); + VERIFY_IS_APPROX(m_permuted, m_original*rp.inverse()); + + m_permuted = m_original; + m_permuted = lp * m_permuted; + VERIFY_IS_APPROX(m_permuted, lp*m_original); + + m_permuted = m_original; + m_permuted = m_permuted * rp; + VERIFY_IS_APPROX(m_permuted, m_original*rp); + if(rows>1 && cols>1) { lp2 = lp; @@ -114,7 +131,7 @@ void test_permutationmatrices() CALL_SUBTEST_2( permutationmatrices(Matrix3f()) ); CALL_SUBTEST_3( permutationmatrices(Matrix<double,3,3,RowMajor>()) ); CALL_SUBTEST_4( permutationmatrices(Matrix4d()) ); - CALL_SUBTEST_5( permutationmatrices(Matrix<double,4,6>()) ); + CALL_SUBTEST_5( permutationmatrices(Matrix<double,40,60>()) ); CALL_SUBTEST_6( permutationmatrices(Matrix<double,Dynamic,Dynamic,RowMajor>(20, 30)) ); CALL_SUBTEST_7( permutationmatrices(MatrixXcf(15, 10)) ); } |