diff options
author | Gael Guennebaud <g.gael@free.fr> | 2009-07-31 17:35:55 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2009-07-31 17:35:55 +0200 |
commit | 18429156a145c1adddcb313512f9f1179a9141cf (patch) | |
tree | abedf8b3755237cb584df2fea27fa0b917032d99 | |
parent | 2796bcabb1151ad8de2bf2ab9117baea40ae4d30 (diff) |
add selfadjointView from a trinagularView
-rw-r--r-- | Eigen/src/Core/TriangularMatrix.h | 21 | ||||
-rw-r--r-- | test/product_trsm.cpp | 65 |
2 files changed, 36 insertions, 50 deletions
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index a41adb190..8b6c9a23b 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -156,7 +156,9 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView typedef typename ei_traits<TriangularView>::Scalar Scalar; typedef _MatrixType MatrixType; typedef typename MatrixType::PlainMatrixType PlainMatrixType; - + typedef typename MatrixType::Nested MatrixTypeNested; + typedef typename ei_cleantype<MatrixTypeNested>::type _MatrixTypeNested; + enum { Mode = _Mode, TransposeMode = (Mode & UpperTriangularBit ? LowerTriangularBit : 0) @@ -286,6 +288,17 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView void solveInPlace(const MatrixBase<OtherDerived>& other) const { return solveInPlace<OnTheLeft>(other); } + const SelfAdjointView<_MatrixTypeNested,Mode> selfadjointView() const + { + EIGEN_STATIC_ASSERT((Mode&UnitDiagBit)==0,PROGRAMMING_ERROR); + return SelfAdjointView<_MatrixTypeNested,Mode>(m_matrix); + } + SelfAdjointView<_MatrixTypeNested,Mode> selfadjointView() + { + EIGEN_STATIC_ASSERT((Mode&UnitDiagBit)==0,PROGRAMMING_ERROR); + return SelfAdjointView<_MatrixTypeNested,Mode>(m_matrix); + } + template<typename OtherDerived> void swap(const TriangularBase<OtherDerived>& other) { @@ -300,7 +313,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView protected: - const typename MatrixType::Nested m_matrix; + const MatrixTypeNested m_matrix; }; /*************************************************************************** @@ -563,6 +576,10 @@ void TriangularBase<Derived>::evalToDenseLazy(MatrixBase<DenseDerived> &other) c } /*************************************************************************** +* Implementation of TriangularView methods +***************************************************************************/ + +/*************************************************************************** * Implementation of MatrixBase methods ***************************************************************************/ diff --git a/test/product_trsm.cpp b/test/product_trsm.cpp index bda158048..4f0fd15be 100644 --- a/test/product_trsm.cpp +++ b/test/product_trsm.cpp @@ -24,12 +24,11 @@ #include "main.h" -template<typename Lhs, typename Rhs> -void solve_ref(const Lhs& lhs, Rhs& rhs) -{ - for (int j=0; j<rhs.cols(); ++j) - lhs.solveInPlace(rhs.col(j)); -} +#define VERIFY_TRSM(TRI,XB) { \ + XB.setRandom(); ref = XB; \ + TRI.template solveInPlace(XB); \ + VERIFY_IS_APPROX(TRI.toDense() * XB, ref); \ + } template<typename Scalar> void trsm(int size,int cols) { @@ -37,53 +36,23 @@ template<typename Scalar> void trsm(int size,int cols) Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size); Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size); - - Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRef(size,cols), cmRhs(size,cols); - Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRef(size,cols), rmRhs(size,cols); - - cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10; - rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10; - - cmRhs.setRandom(); cmRef = cmRhs; - cmLhs.conjugate().template triangularView<LowerTriangular>().solveInPlace(cmRhs); - solve_ref(cmLhs.conjugate().template triangularView<LowerTriangular>(),cmRef); - VERIFY_IS_APPROX(cmRhs, cmRef); - - cmRhs.setRandom(); cmRef = cmRhs; - cmLhs.conjugate().template triangularView<UpperTriangular>().solveInPlace(cmRhs); - solve_ref(cmLhs.conjugate().template triangularView<UpperTriangular>(),cmRef); - VERIFY_IS_APPROX(cmRhs, cmRef); - - rmRhs.setRandom(); rmRef = rmRhs; - cmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs); - solve_ref(cmLhs.template triangularView<LowerTriangular>(),rmRef); - VERIFY_IS_APPROX(rmRhs, rmRef); - - rmRhs.setRandom(); rmRef = rmRhs; - cmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs); - solve_ref(cmLhs.template triangularView<UpperTriangular>(),rmRef); - VERIFY_IS_APPROX(rmRhs, rmRef); + Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRhs(size,cols), ref(size,cols); + Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRhs(size,cols); - cmRhs.setRandom(); cmRef = cmRhs; - rmLhs.template triangularView<UnitLowerTriangular>().solveInPlace(cmRhs); - solve_ref(rmLhs.template triangularView<UnitLowerTriangular>(),cmRef); - VERIFY_IS_APPROX(cmRhs, cmRef); + cmLhs.setRandom(); cmLhs *= 0.1; cmLhs.diagonal().cwise() += 1; + rmLhs.setRandom(); rmLhs *= 0.1; rmLhs.diagonal().cwise() += 1; - cmRhs.setRandom(); cmRef = cmRhs; - rmLhs.template triangularView<UnitUpperTriangular>().solveInPlace(cmRhs); - solve_ref(rmLhs.template triangularView<UnitUpperTriangular>(),cmRef); - VERIFY_IS_APPROX(cmRhs, cmRef); + VERIFY_TRSM(cmLhs.conjugate().template triangularView<LowerTriangular>(), cmRhs); + VERIFY_TRSM(cmLhs .template triangularView<UpperTriangular>(), cmRhs); + VERIFY_TRSM(cmLhs .template triangularView<LowerTriangular>(), rmRhs); + VERIFY_TRSM(cmLhs.conjugate().template triangularView<UpperTriangular>(), rmRhs); - rmRhs.setRandom(); rmRef = rmRhs; - rmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs); - solve_ref(rmLhs.template triangularView<LowerTriangular>(),rmRef); - VERIFY_IS_APPROX(rmRhs, rmRef); + VERIFY_TRSM(cmLhs.conjugate().template triangularView<UnitLowerTriangular>(), cmRhs); + VERIFY_TRSM(cmLhs .template triangularView<UnitUpperTriangular>(), rmRhs); - rmRhs.setRandom(); rmRef = rmRhs; - rmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs); - solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef); - VERIFY_IS_APPROX(rmRhs, rmRef); + VERIFY_TRSM(rmLhs .template triangularView<LowerTriangular>(), cmRhs); + VERIFY_TRSM(rmLhs.conjugate().template triangularView<UnitUpperTriangular>(), rmRhs); } void test_product_trsm() |