diff options
author | Benoit Jacob <jacob.benoit.1@gmail.com> | 2011-01-25 21:22:04 -0500 |
---|---|---|
committer | Benoit Jacob <jacob.benoit.1@gmail.com> | 2011-01-25 21:22:04 -0500 |
commit | 1d98cc5e5da88254c784c4f02517bf5a47f007bc (patch) | |
tree | de9bc9696892cf0bfa35a820b19b34524f4c954d | |
parent | 4fbadfd230af9859c5e82c09a4974396af45473f (diff) |
eigen2 support: implement part<SelfAdjoint>, mimic eigen2 behavior braindeadness-for-braindeadness
-rw-r--r-- | Eigen/src/Core/MatrixBase.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/SelfAdjointView.h | 24 | ||||
-rw-r--r-- | Eigen/src/Core/TriangularMatrix.h | 45 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 3 | ||||
-rw-r--r-- | test/eigen2/eigen2_triangular.cpp | 29 |
5 files changed, 92 insertions, 17 deletions
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index da4af6bfd..fbdc059cf 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -242,16 +242,16 @@ template<typename Derived> class MatrixBase typename MatrixBase::template DiagonalIndexReturnType<Dynamic>::Type diagonal(Index index); typename MatrixBase::template ConstDiagonalIndexReturnType<Dynamic>::Type diagonal(Index index) const; - #ifdef EIGEN2_SUPPORT - template<unsigned int Mode> TriangularView<Derived, Mode> part(); - template<unsigned int Mode> const TriangularView<Derived, Mode> part() const; + //#ifdef EIGEN2_SUPPORT + template<unsigned int Mode> typename internal::eigen2_part_return_type<Derived, Mode>::type part(); + template<unsigned int Mode> const typename internal::eigen2_part_return_type<Derived, Mode>::type part() const; // huuuge hack. make Eigen2's matrix.part<Diagonal>() work in eigen3. Problem: Diagonal is now a class template instead // of an integer constant. Solution: overload the part() method template wrt template parameters list. template<template<typename T, int n> class U> const DiagonalWrapper<ConstDiagonalReturnType> part() const { return diagonal().asDiagonal(); } - #endif // EIGEN2_SUPPORT + //#endif // EIGEN2_SUPPORT template<unsigned int Mode> struct TriangularViewReturnType { typedef TriangularView<Derived, Mode> Type; }; template<unsigned int Mode> struct ConstTriangularViewReturnType { typedef const TriangularView<const Derived, Mode> Type; }; diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 5d8468884..92d58b9f8 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -48,6 +48,7 @@ struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType> typedef typename nested<MatrixType>::type MatrixTypeNested; typedef typename remove_reference<MatrixTypeNested>::type _MatrixTypeNested; typedef MatrixType ExpressionType; + typedef typename MatrixType::PlainObject DenseMatrixType; enum { Mode = UpLo | SelfAdjoint, Flags = _MatrixTypeNested::Flags & (HereditaryBits) @@ -171,6 +172,29 @@ template<typename MatrixType, unsigned int UpLo> class SelfAdjointView EigenvaluesReturnType eigenvalues() const; RealScalar operatorNorm() const; + + #ifdef EIGEN2_SUPPORT + template<typename OtherDerived> + SelfAdjointView& operator=(const MatrixBase<OtherDerived>& other) + { + enum { + OtherPart = UpLo == Upper ? StrictlyLower : StrictlyUpper + }; + m_matrix.const_cast_derived().template triangularView<UpLo>() = other; + m_matrix.const_cast_derived().template triangularView<OtherPart>() = other.adjoint(); + return *this; + } + template<typename OtherMatrixType, unsigned int OtherMode> + SelfAdjointView& operator=(const TriangularView<OtherMatrixType, OtherMode>& other) + { + enum { + OtherPart = UpLo == Upper ? StrictlyLower : StrictlyUpper + }; + m_matrix.const_cast_derived().template triangularView<UpLo>() = other.toDenseMatrix(); + m_matrix.const_cast_derived().template triangularView<OtherPart>() = other.toDenseMatrix().adjoint(); + return *this; + } + #endif protected: const typename MatrixType::Nested m_matrix; diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index ce5b53631..714d56a5b 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -48,6 +48,7 @@ template<typename Derived> class TriangularBase : public EigenBase<Derived> typedef typename internal::traits<Derived>::Scalar Scalar; typedef typename internal::traits<Derived>::StorageKind StorageKind; typedef typename internal::traits<Derived>::Index Index; + typedef typename internal::traits<Derived>::DenseMatrixType DenseMatrixType; inline TriangularBase() { eigen_assert(!((Mode&UnitDiag) && (Mode&ZeroDiag))); } @@ -88,6 +89,13 @@ template<typename Derived> class TriangularBase : public EigenBase<Derived> template<typename DenseDerived> void evalToLazy(MatrixBase<DenseDerived> &other) const; + DenseMatrixType toDenseMatrix() const + { + DenseMatrixType res(rows(), cols()); + evalToLazy(res); + return res; + } + protected: void check_coordinates(Index row, Index col) const @@ -137,6 +145,7 @@ struct traits<TriangularView<MatrixType, _Mode> > : traits<MatrixType> typedef typename nested<MatrixType>::type MatrixTypeNested; typedef typename remove_reference<MatrixTypeNested>::type _MatrixTypeNested; typedef MatrixType ExpressionType; + typedef typename MatrixType::PlainObject DenseMatrixType; enum { Mode = _Mode, Flags = (_MatrixTypeNested::Flags & (HereditaryBits) & (~(PacketAccessBit | DirectAccessBit | LinearAccessBit))) | Mode, @@ -159,7 +168,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView typedef typename internal::traits<TriangularView>::Scalar Scalar; typedef _MatrixType MatrixType; - typedef typename MatrixType::PlainObject DenseMatrixType; + typedef typename internal::traits<TriangularView>::DenseMatrixType DenseMatrixType; protected: typedef typename MatrixType::Nested MatrixTypeNested; @@ -269,13 +278,6 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView inline const TriangularView<Transpose<MatrixType>,TransposeMode> transpose() const { return m_matrix.transpose(); } - DenseMatrixType toDenseMatrix() const - { - DenseMatrixType res(rows(), cols()); - evalToLazy(res); - return res; - } - /** Efficient triangular matrix times vector/matrix product */ template<typename OtherDerived> TriangularProduct<Mode,true,MatrixType,false,OtherDerived,OtherDerived::IsVectorAtCompileTime> @@ -310,18 +312,18 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView const typename eigen2_product_return_type<OtherMatrixType>::type operator*(const TriangularView<OtherMatrixType, Mode>& rhs) const { - return toDenseMatrix() * rhs.toDenseMatrix(); + return this->toDenseMatrix() * rhs.toDenseMatrix(); } template<typename OtherMatrixType> bool isApprox(const TriangularView<OtherMatrixType, Mode>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const { - return toDenseMatrix().isApprox(other.toDenseMatrix(), precision); + return this->toDenseMatrix().isApprox(other.toDenseMatrix(), precision); } template<typename OtherDerived> bool isApprox(const MatrixBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const { - return toDenseMatrix().isApprox(other, precision); + return this->toDenseMatrix().isApprox(other, precision); } #endif // EIGEN2_SUPPORT @@ -707,10 +709,27 @@ void TriangularBase<Derived>::evalToLazy(MatrixBase<DenseDerived> &other) const ***************************************************************************/ #ifdef EIGEN2_SUPPORT + +// implementation of part<>(), including the SelfAdjoint case. + +namespace internal { +template<typename MatrixType, unsigned int Mode> +struct eigen2_part_return_type +{ + typedef TriangularView<MatrixType, Mode> type; +}; + +template<typename MatrixType> +struct eigen2_part_return_type<MatrixType, SelfAdjoint> +{ + typedef SelfAdjointView<MatrixType, Upper> type; +}; +} + /** \deprecated use MatrixBase::triangularView() */ template<typename Derived> template<unsigned int Mode> -const TriangularView<Derived, Mode> MatrixBase<Derived>::part() const +const typename internal::eigen2_part_return_type<Derived, Mode>::type MatrixBase<Derived>::part() const { return derived(); } @@ -718,7 +737,7 @@ const TriangularView<Derived, Mode> MatrixBase<Derived>::part() const /** \deprecated use MatrixBase::triangularView() */ template<typename Derived> template<unsigned int Mode> -TriangularView<Derived, Mode> MatrixBase<Derived>::part() +typename internal::eigen2_part_return_type<Derived, Mode>::type MatrixBase<Derived>::part() { return derived(); } diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 578f8d8e6..548da3986 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -268,6 +268,9 @@ template<typename ExpressionType> class Cwise; template<typename MatrixType> class Minor; template<typename MatrixType> class LU; template<typename MatrixType> class QR; +namespace internal { +template<typename MatrixType, unsigned int Mode> struct eigen2_part_return_type; +} #endif #endif // EIGEN_FORWARDDECLARATIONS_H diff --git a/test/eigen2/eigen2_triangular.cpp b/test/eigen2/eigen2_triangular.cpp index c81fad0da..43b42e3a5 100644 --- a/test/eigen2/eigen2_triangular.cpp +++ b/test/eigen2/eigen2_triangular.cpp @@ -124,8 +124,37 @@ template<typename MatrixType> void triangular(const MatrixType& m) } +void selfadjoint() +{ + Matrix2i m; + m << 1, 2, + 3, 4; + + Matrix2i m1 = Matrix2i::Zero(); + m1.part<SelfAdjoint>() = m; + Matrix2i ref1; + ref1 << 1, 2, + 2, 4; + VERIFY(m1 == ref1); + + Matrix2i m2 = Matrix2i::Zero(); + m2.part<SelfAdjoint>() = m.part<UpperTriangular>(); + Matrix2i ref2; + ref2 << 1, 2, + 2, 4; + VERIFY(m2 == ref2); + + Matrix2i m3 = Matrix2i::Zero(); + m3.part<SelfAdjoint>() = m.part<LowerTriangular>(); + Matrix2i ref3; + ref3 << 1, 0, + 0, 4; + VERIFY(m3 == ref3); +} + void test_eigen2_triangular() { + CALL_SUBTEST_8( selfadjoint() ); for(int i = 0; i < g_repeat ; i++) { CALL_SUBTEST_1( triangular(Matrix<float, 1, 1>()) ); CALL_SUBTEST_2( triangular(Matrix<float, 2, 2>()) ); |