diff options
author | Gael Guennebaud <g.gael@free.fr> | 2016-12-20 16:33:53 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2016-12-20 16:33:53 +0100 |
commit | 684cfc762d70e8ab766bc94968d8d5e462c44074 (patch) | |
tree | 57c3868e3a092abb9a1b28c34ba04cb1f1988e76 | |
parent | 8bd0d3aa345b1b10b1666401aad0e66e0a3a8303 (diff) |
Add transpose, adjoint, conjugate methods to SelfAdjointView (useful to write generic code)
-rw-r--r-- | Eigen/src/Core/SelfAdjointView.h | 36 | ||||
-rw-r--r-- | test/product_symm.cpp | 18 |
2 files changed, 52 insertions, 2 deletions
diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 62d4180da..06484ab30 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -45,7 +45,7 @@ struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType> }; } -// FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? + template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView : public TriangularBase<SelfAdjointView<_MatrixType, UpLo> > { @@ -60,10 +60,12 @@ template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView /** \brief The type of coefficients in this matrix */ typedef typename internal::traits<SelfAdjointView>::Scalar Scalar; typedef typename MatrixType::StorageIndex StorageIndex; + typedef typename internal::remove_all<typename MatrixType::ConjugateReturnType>::type MatrixConjugateReturnType; enum { Mode = internal::traits<SelfAdjointView>::Mode, - Flags = internal::traits<SelfAdjointView>::Flags + Flags = internal::traits<SelfAdjointView>::Flags, + TransposeMode = ((Mode & Upper) ? Lower : 0) | ((Mode & Lower) ? Upper : 0) }; typedef typename MatrixType::PlainObject PlainObject; @@ -187,6 +189,36 @@ template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView TriangularView<typename MatrixType::AdjointReturnType,TriMode> >::type(tmp2); } + typedef SelfAdjointView<const MatrixConjugateReturnType,Mode> ConjugateReturnType; + /** \sa MatrixBase::conjugate() const */ + EIGEN_DEVICE_FUNC + inline const ConjugateReturnType conjugate() const + { return ConjugateReturnType(m_matrix.conjugate()); } + + typedef SelfAdjointView<const typename MatrixType::AdjointReturnType,TransposeMode> AdjointReturnType; + /** \sa MatrixBase::adjoint() const */ + EIGEN_DEVICE_FUNC + inline const AdjointReturnType adjoint() const + { return AdjointReturnType(m_matrix.adjoint()); } + + typedef SelfAdjointView<typename MatrixType::TransposeReturnType,TransposeMode> TransposeReturnType; + /** \sa MatrixBase::transpose() */ + EIGEN_DEVICE_FUNC + inline TransposeReturnType transpose() + { + EIGEN_STATIC_ASSERT_LVALUE(MatrixType) + typename MatrixType::TransposeReturnType tmp(m_matrix); + return TransposeReturnType(tmp); + } + + typedef SelfAdjointView<const typename MatrixType::ConstTransposeReturnType,TransposeMode> ConstTransposeReturnType; + /** \sa MatrixBase::transpose() const */ + EIGEN_DEVICE_FUNC + inline const ConstTransposeReturnType transpose() const + { + return ConstTransposeReturnType(m_matrix.transpose()); + } + /** \returns a const expression of the main diagonal of the matrix \c *this * * This method simply returns the diagonal of the nested expression, thus by-passing the SelfAdjointView decorator. diff --git a/test/product_symm.cpp b/test/product_symm.cpp index 74d7329b1..8c44383f9 100644 --- a/test/product_symm.cpp +++ b/test/product_symm.cpp @@ -39,6 +39,24 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in VERIFY_IS_APPROX(rhs12 = (s1*m2).template selfadjointView<Lower>() * (s2*rhs1), rhs13 = (s1*m1) * (s2*rhs1)); + VERIFY_IS_APPROX(rhs12 = (s1*m2).transpose().template selfadjointView<Upper>() * (s2*rhs1), + rhs13 = (s1*m1.transpose()) * (s2*rhs1)); + + VERIFY_IS_APPROX(rhs12 = (s1*m2).template selfadjointView<Lower>().transpose() * (s2*rhs1), + rhs13 = (s1*m1.transpose()) * (s2*rhs1)); + + VERIFY_IS_APPROX(rhs12 = (s1*m2).conjugate().template selfadjointView<Lower>() * (s2*rhs1), + rhs13 = (s1*m1).conjugate() * (s2*rhs1)); + + VERIFY_IS_APPROX(rhs12 = (s1*m2).template selfadjointView<Lower>().conjugate() * (s2*rhs1), + rhs13 = (s1*m1).conjugate() * (s2*rhs1)); + + VERIFY_IS_APPROX(rhs12 = (s1*m2).adjoint().template selfadjointView<Upper>() * (s2*rhs1), + rhs13 = (s1*m1).adjoint() * (s2*rhs1)); + + VERIFY_IS_APPROX(rhs12 = (s1*m2).template selfadjointView<Lower>().adjoint() * (s2*rhs1), + rhs13 = (s1*m1).adjoint() * (s2*rhs1)); + m2 = m1.template triangularView<Upper>(); rhs12.setRandom(); rhs13 = rhs12; m3 = m2.template selfadjointView<Upper>(); VERIFY_IS_EQUAL(m1, m3); |