aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-02-25 16:30:58 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-02-25 16:30:58 +0100
commit959a1b5d6335833e9ad49a088502705bb6967ff5 (patch)
tree0b41aa2f93a5b009b3699809ccfd4000582c92bf
parentd9ca0c0d3643f4b777de686a2c0cddde075aa063 (diff)
detect and implement inplace permutations
-rw-r--r--Eigen/src/Core/PermutationMatrix.h53
-rw-r--r--Eigen/src/Core/Transpose.h19
-rw-r--r--Eigen/src/Core/util/BlasUtil.h18
-rw-r--r--Eigen/src/LU/FullPivLU.h8
-rw-r--r--Eigen/src/LU/PartialPivLU.h5
-rw-r--r--test/permutationmatrices.cpp19
6 files changed, 80 insertions, 42 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h
index c42812ec8..46884dc3f 100644
--- a/Eigen/src/Core/PermutationMatrix.h
+++ b/Eigen/src/Core/PermutationMatrix.h
@@ -326,21 +326,46 @@ struct ei_permut_matrix_product_retval
template<typename Dest> inline void evalTo(Dest& dst) const
{
const int n = Side==OnTheLeft ? rows() : cols();
- for(int i = 0; i < n; ++i)
+
+ if(ei_is_same_type<MatrixTypeNestedCleaned,Dest>::ret && ei_extract_data(dst) == ei_extract_data(m_matrix))
+ {
+ // apply the permutation inplace
+ Matrix<bool,PermutationType::RowsAtCompileTime,1,0,PermutationType::MaxRowsAtCompileTime> mask(m_permutation.size());
+ mask.fill(false);
+ int r = 0;
+ while(r < m_permutation.size())
+ {
+ // search for the next seed
+ while(r<m_permutation.size() && mask[r]) r++;
+ if(r>=m_permutation.size())
+ break;
+ // we got one, let's follow it until we are back to the seed
+ int k0 = r++;
+ int kPrev = k0;
+ mask.coeffRef(k0) = true;
+ for(int k=m_permutation.indices().coeff(k0); k!=k0; k=m_permutation.indices().coeff(k))
+ {
+ Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k)
+ .swap(Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>
+ (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev));
+
+ mask.coeffRef(k) = true;
+ kPrev = k;
+ }
+ }
+ }
+ else
{
- Block<
- Dest,
- Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime,
- Side==OnTheRight ? 1 : Dest::ColsAtCompileTime
- >(dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i)
-
- =
-
- Block<
- MatrixTypeNestedCleaned,
- Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,
- Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime
- >(m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i);
+ for(int i = 0; i < n; ++i)
+ {
+ Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>
+ (dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i)
+
+ =
+
+ Block<MatrixTypeNestedCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime>
+ (m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i);
+ }
}
}
diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h
index bd06d8464..6c0e50de2 100644
--- a/Eigen/src/Core/Transpose.h
+++ b/Eigen/src/Core/Transpose.h
@@ -295,25 +295,6 @@ struct ei_blas_traits<SelfCwiseBinaryOp<BinOp,NestedXpr> >
static inline const XprType extract(const XprType& x) { return x; }
};
-
-template<typename T, int Access=ei_blas_traits<T>::ActualAccess>
-struct ei_extract_data_selector {
- static typename T::Scalar* run(const T& m)
- {
- return &ei_blas_traits<T>::extract(m).const_cast_derived().coeffRef(0,0);
- }
-};
-
-template<typename T>
-struct ei_extract_data_selector<T,NoDirectAccess> {
- static typename T::Scalar* run(const T&) { return 0; }
-};
-
-template<typename T> typename T::Scalar* ei_extract_data(const T& m)
-{
- return ei_extract_data_selector<T>::run(m);
-}
-
template<typename Scalar, bool DestIsTranposed, typename OtherDerived>
struct ei_check_transpose_aliasing_selector
{
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 2ca463d5d..4d216d77a 100644
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -236,4 +236,22 @@ struct ei_blas_traits<Transpose<NestedXpr> >
static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
};
+template<typename T, int Access=ei_blas_traits<T>::ActualAccess>
+struct ei_extract_data_selector {
+ static const typename T::Scalar* run(const T& m)
+ {
+ return &ei_blas_traits<T>::extract(m).const_cast_derived().coeffRef(0,0); // FIXME this should be .data()
+ }
+};
+
+template<typename T>
+struct ei_extract_data_selector<T,NoDirectAccess> {
+ static typename T::Scalar* run(const T&) { return 0; }
+};
+
+template<typename T> const typename T::Scalar* ei_extract_data(const T& m)
+{
+ return ei_extract_data_selector<T>::run(m);
+}
+
#endif // EIGEN_BLASUTIL_H
diff --git a/Eigen/src/LU/FullPivLU.h b/Eigen/src/LU/FullPivLU.h
index cd63b9ec7..dea6ec41c 100644
--- a/Eigen/src/LU/FullPivLU.h
+++ b/Eigen/src/LU/FullPivLU.h
@@ -119,7 +119,7 @@ template<typename _MatrixType> class FullPivLU
* diagonal coefficient of U.
*/
RealScalar maxPivot() const { return m_maxpivot; }
-
+
/** \returns the permutation matrix P
*
* \sa permutationQ()
@@ -506,12 +506,10 @@ MatrixType FullPivLU<MatrixType>::reconstructedMatrix() const
.template triangularView<Upper>().toDenseMatrix();
// P^{-1}(LU)
- // FIXME implement inplace permutation
- res = (m_p.inverse() * res).eval();
+ res = m_p.inverse() * res;
// (P^{-1}LU)Q^{-1}
- // FIXME implement inplace permutation
- res = (res * m_q.inverse()).eval();
+ res = res * m_q.inverse();
return res;
}
diff --git a/Eigen/src/LU/PartialPivLU.h b/Eigen/src/LU/PartialPivLU.h
index fcffc2458..ad0d6b810 100644
--- a/Eigen/src/LU/PartialPivLU.h
+++ b/Eigen/src/LU/PartialPivLU.h
@@ -412,10 +412,9 @@ MatrixType PartialPivLU<MatrixType>::reconstructedMatrix() const
// LU
MatrixType res = m_lu.template triangularView<UnitLower>().toDenseMatrix()
* m_lu.template triangularView<Upper>();
-
+
// P^{-1}(LU)
- // FIXME implement inplace permutation
- res = (m_p.inverse() * res).eval();
+ res = m_p.inverse() * res;
return res;
}
diff --git a/test/permutationmatrices.cpp b/test/permutationmatrices.cpp
index ae1bd8b85..89142d910 100644
--- a/test/permutationmatrices.cpp
+++ b/test/permutationmatrices.cpp
@@ -86,6 +86,23 @@ template<typename MatrixType> void permutationmatrices(const MatrixType& m)
identityp.setIdentity(rows);
VERIFY_IS_APPROX(m_original, identityp*m_original);
+ // check inplace permutations
+ m_permuted = m_original;
+ m_permuted = lp.inverse() * m_permuted;
+ VERIFY_IS_APPROX(m_permuted, lp.inverse()*m_original);
+
+ m_permuted = m_original;
+ m_permuted = m_permuted * rp.inverse();
+ VERIFY_IS_APPROX(m_permuted, m_original*rp.inverse());
+
+ m_permuted = m_original;
+ m_permuted = lp * m_permuted;
+ VERIFY_IS_APPROX(m_permuted, lp*m_original);
+
+ m_permuted = m_original;
+ m_permuted = m_permuted * rp;
+ VERIFY_IS_APPROX(m_permuted, m_original*rp);
+
if(rows>1 && cols>1)
{
lp2 = lp;
@@ -114,7 +131,7 @@ void test_permutationmatrices()
CALL_SUBTEST_2( permutationmatrices(Matrix3f()) );
CALL_SUBTEST_3( permutationmatrices(Matrix<double,3,3,RowMajor>()) );
CALL_SUBTEST_4( permutationmatrices(Matrix4d()) );
- CALL_SUBTEST_5( permutationmatrices(Matrix<double,4,6>()) );
+ CALL_SUBTEST_5( permutationmatrices(Matrix<double,40,60>()) );
CALL_SUBTEST_6( permutationmatrices(Matrix<double,Dynamic,Dynamic,RowMajor>(20, 30)) );
CALL_SUBTEST_7( permutationmatrices(MatrixXcf(15, 10)) );
}