diff options
-rw-r--r-- | Eigen/src/Core/PermutationMatrix.h | 30 | ||||
-rw-r--r-- | test/permutationmatrices.cpp | 9 |
2 files changed, 36 insertions, 3 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index aaccb4e7b..1c66cde8e 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -105,10 +105,10 @@ class PermutationMatrix : public AnyMatrixBase<PermutationMatrix<SizeAtCompileTi ei_assert(rows == cols); } - /** \returns the number of columns */ + /** \returns the number of rows */ inline int rows() const { return m_indices.size(); } - /** \returns the number of rows */ + /** \returns the number of columns */ inline int cols() const { return m_indices.size(); } template<typename DenseDerived> @@ -126,7 +126,31 @@ class PermutationMatrix : public AnyMatrixBase<PermutationMatrix<SizeAtCompileTi const IndicesType& indices() const { return m_indices; } IndicesType& indices() { return m_indices; } - + + /**** inversion and multiplication helpers to hopefully get RVO ****/ + + protected: + enum Inverse_t {Inverse}; + PermutationMatrix(Inverse_t, const PermutationMatrix& other) + : m_indices(other.m_indices.size()) + { + for (int i=0; i<rows();++i) m_indices.coeffRef(other.m_indices.coeff(i)) = i; + } + enum Product_t {Product}; + PermutationMatrix(Product_t, const PermutationMatrix& lhs, const PermutationMatrix& rhs) + : m_indices(lhs.m_indices.size()) + { + ei_assert(lhs.cols() == rhs.rows()); + for (int i=0; i<rows();++i) m_indices.coeffRef(i) = lhs.m_indices.coeff(rhs.m_indices.coeff(i)); + } + + public: + inline PermutationMatrix inverse() const + { return PermutationMatrix(Inverse, *this); } + template<int OtherSize, int OtherMaxSize> + inline PermutationMatrix operator*(const PermutationMatrix<OtherSize, OtherMaxSize>& other) const + { return PermutationMatrix(Product, *this, other); } + protected: IndicesType m_indices; diff --git a/test/permutationmatrices.cpp b/test/permutationmatrices.cpp index 13b01cd83..ec3a8541c 100644 --- a/test/permutationmatrices.cpp +++ b/test/permutationmatrices.cpp @@ -72,6 +72,15 @@ template<typename MatrixType> void permutationmatrices(const MatrixType& m) Matrix<Scalar,Cols,Cols> rm(rp); VERIFY_IS_APPROX(m_permuted, lm*m_original*rm); + + VERIFY_IS_APPROX(lp.inverse()*m_permuted*rp.inverse(), m_original); + VERIFY((lp*lp.inverse()).toDenseMatrix().isIdentity()); + + LeftPermutationVectorType lv2; + randomPermutationVector(lv2, rows); + LeftPermutationType lp2(lv2); + Matrix<Scalar,Rows,Rows> lm2(lp2); + VERIFY_IS_APPROX((lp*lp2).toDenseMatrix().template cast<Scalar>(), lm2*lm); } void test_permutationmatrices() |