diff options
author | Benoit Jacob <jacob.benoit.1@gmail.com> | 2008-08-05 21:55:57 +0000 |
---|---|---|
committer | Benoit Jacob <jacob.benoit.1@gmail.com> | 2008-08-05 21:55:57 +0000 |
commit | 88bb2087c100d21f40546829759ebd58d14947ff (patch) | |
tree | ad83966c8bcafbce92a740b5b2704a49e47b810d /Eigen/src | |
parent | c94be35bc816ecd963d6d2f93b0f1fa67009e3d2 (diff) |
New implementation of Swap as discussed, reusing Assign. Makes LU run
10% faster overall.
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/Swap.h | 136 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 1 |
2 files changed, 90 insertions, 47 deletions
diff --git a/Eigen/src/Core/Swap.h b/Eigen/src/Core/Swap.h
index 3b864789e..0ee57017e 100644
--- a/Eigen/src/Core/Swap.h
+++ b/Eigen/src/Core/Swap.h
@@ -25,8 +25,94 @@
 #ifndef EIGEN_SWAP_H
 #define EIGEN_SWAP_H
 
-template <typename Derived, typename OtherDerived, bool IsVector = Derived::IsVectorAtCompileTime>
-struct ei_swap_selector;
+/** \class SwapWrapper
+  *
+  * \brief Internal helper expression used to swap two expressions
+  *
+  * \param ExpressionType the type of the expression wrapped on the destination side of the swap
+  *
+  * Its copyCoeff() and copyPacket() methods exchange the coefficients of the wrapped
+  * expression with those of \a other instead of overwriting them.
+  *
+  * \sa MatrixBase::swap()
+  */
+template<typename ExpressionType>
+struct ei_traits<SwapWrapper<ExpressionType> >
+{
+  typedef typename ExpressionType::Scalar Scalar;
+  enum {
+    RowsAtCompileTime = ExpressionType::RowsAtCompileTime,
+    ColsAtCompileTime = ExpressionType::ColsAtCompileTime,
+    MaxRowsAtCompileTime = ExpressionType::MaxRowsAtCompileTime,
+    MaxColsAtCompileTime = ExpressionType::MaxColsAtCompileTime,
+    Flags = ExpressionType::Flags,
+    CoeffReadCost = ExpressionType::CoeffReadCost
+  };
+};
+
+template<typename ExpressionType> class SwapWrapper
+  : public MatrixBase<SwapWrapper<ExpressionType> >
+{
+  public:
+
+    EIGEN_GENERIC_PUBLIC_INTERFACE(SwapWrapper)
+    typedef typename ei_packet_traits<Scalar>::type Packet;
+
+    inline SwapWrapper(ExpressionType& matrix) : m_expression(matrix) {}
+
+    inline int rows() const { return m_expression.rows(); }
+    inline int cols() const { return m_expression.cols(); }
+    inline int stride() const { return m_expression.stride(); }
+
+    template<typename OtherDerived>
+    void copyCoeff(int row, int col, const MatrixBase<OtherDerived>& other)
+    {
+      OtherDerived& _other = other.const_cast_derived(); // other is written to below, hence the const_cast
+      ei_internal_assert(row >= 0 && row < rows()
+                         && col >= 0 && col < cols());
+      Scalar tmp = m_expression.coeff(row, col);
+      m_expression.coeffRef(row, col) = _other.coeff(row, col);
+      _other.coeffRef(row, col) = tmp;
+    }
+
+    template<typename OtherDerived>
+    void copyCoeff(int index, const MatrixBase<OtherDerived>& other)
+    {
+      OtherDerived& _other = other.const_cast_derived();
+      ei_internal_assert(index >= 0 && index < m_expression.size());
+      Scalar tmp = m_expression.coeff(index);
+      m_expression.coeffRef(index) = _other.coeff(index);
+      _other.coeffRef(index) = tmp;
+    }
+
+    template<typename OtherDerived, int LoadStoreMode>
+    void copyPacket(int row, int col, const MatrixBase<OtherDerived>& other)
+    {
+      OtherDerived& _other = other.const_cast_derived();
+      ei_internal_assert(row >= 0 && row < rows()
+                        && col >= 0 && col < cols());
+      Packet tmp = m_expression.template packet<LoadStoreMode>(row, col);
+      m_expression.template writePacket<LoadStoreMode>(row, col,
+        _other.template packet<LoadStoreMode>(row, col)
+      );
+      _other.template writePacket<LoadStoreMode>(row, col, tmp);
+    }
+
+    template<typename OtherDerived, int LoadStoreMode>
+    void copyPacket(int index, const MatrixBase<OtherDerived>& other)
+    {
+      OtherDerived& _other = other.const_cast_derived();
+      ei_internal_assert(index >= 0 && index < m_expression.size());
+      Packet tmp = m_expression.template packet<LoadStoreMode>(index);
+      m_expression.template writePacket<LoadStoreMode>(index,
+        _other.template packet<LoadStoreMode>(index)
+      );
+      _other.template writePacket<LoadStoreMode>(index, tmp);
+    }
+
+  protected:
+    ExpressionType& m_expression; // held by reference: a by-value member would swap into a copy of the caller's object
+};
 
 /** swaps *this with the expression \a other.
   *
@@ -41,51 +127,7 @@ template<typename Derived>
 template<typename OtherDerived>
 void MatrixBase<Derived>::swap(const MatrixBase<OtherDerived>& other)
 {
-  MatrixBase<OtherDerived> *_other = const_cast<MatrixBase<OtherDerived>*>(&other);
-
-  // disable that path: it makes LU decomposition fail ! I can't see the bug though.
-  if(false /*SizeAtCompileTime == Dynamic*/)
-  {
-    ei_swap_selector<Derived,OtherDerived>::run(derived(),other.const_cast_derived());
-  }
-  else // SizeAtCompileTime != Dynamic
-  {
-    typename Derived::Eval buf(*this);
-    *this = other;
-    *_other = buf;
-  }
+  SwapWrapper<Derived>(derived()).lazyAssign(other);
 }
 
-template<typename Derived, typename OtherDerived>
-struct ei_swap_selector<Derived,OtherDerived,true>
-{
-  inline static void run(Derived& src, OtherDerived& other)
-  {
-    typename Derived::Scalar tmp;
-    ei_assert(OtherDerived::IsVectorAtCompileTime && src.size() == other.size());
-    for(int i = 0; i < src.size(); i++)
-    {
-      tmp = src.coeff(i);
-      src.coeffRef(i) = other.coeff(i);
-      other.coeffRef(i) = tmp;
-    }
-  }
-};
-
-template<typename Derived, typename OtherDerived>
-struct ei_swap_selector<Derived,OtherDerived,false>
-{
-  inline static void run(Derived& src, OtherDerived& other)
-  {
-    typename Derived::Scalar tmp;
-    for(int j = 0; j < src.cols(); j++)
-      for(int i = 0; i < src.rows(); i++)
-      {
-        tmp = src.coeff(i, j);
-        src.coeffRef(i, j) = other.coeff(i, j);
-        other.coeffRef(i, j) = tmp;
-      }
-  }
-};
-
 #endif // EIGEN_SWAP_H
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index 067ccd0b0..a886a90d0 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -41,6 +41,7 @@ class Matrix;
 
 template<typename ExpressionType, unsigned int Added, unsigned int Removed> class Flagged;
 template<typename ExpressionType> class NestByValue;
+template<typename ExpressionType> class SwapWrapper;
 template<typename MatrixType> class Minor;
 template<typename MatrixType, int BlockRows=Dynamic, int BlockCols=Dynamic,
          int DirectAccessStatus = ei_traits<MatrixType>::Flags&DirectAccessBit> class Block;