diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 11:50:24 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 11:50:24 +0200 |
commit | 3af4c6c1c9327411d13386e4719ce48f866c7567 (patch) | |
tree | 1cc76e62ae2eaba007b658526534096abe540d43 /Eigen/src/Core/Transpositions.h | |
parent | 82b6ac08646f7b12770665134acaf3bb3cdc4dd3 (diff) |
Make Transpositions use evaluators
Diffstat (limited to 'Eigen/src/Core/Transpositions.h')
-rw-r--r-- | Eigen/src/Core/Transpositions.h | 127 |
1 files changed, 54 insertions, 73 deletions
diff --git a/Eigen/src/Core/Transpositions.h b/Eigen/src/Core/Transpositions.h index b08df1ead..dad4f56c9 100644 --- a/Eigen/src/Core/Transpositions.h +++ b/Eigen/src/Core/Transpositions.h @@ -41,10 +41,6 @@ namespace Eigen { * \sa class PermutationMatrix */ -namespace internal { -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval; -} - template<typename Derived> class TranspositionsBase { @@ -325,77 +321,32 @@ class TranspositionsWrapper const typename IndicesType::Nested m_indices; }; + + /** \returns the \a matrix with the \a transpositions applied to the columns. */ -template<typename Derived, typename TranspositionsDerived> -inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight> -operator*(const MatrixBase<Derived>& matrix, - const TranspositionsBase<TranspositionsDerived> &transpositions) +template<typename MatrixDerived, typename TranspositionsDerived> +EIGEN_DEVICE_FUNC +const Product<MatrixDerived, TranspositionsDerived, DefaultProduct> +operator*(const MatrixBase<MatrixDerived> &matrix, + const TranspositionsBase<TranspositionsDerived>& transpositions) { - return internal::transposition_matrix_product_retval - <TranspositionsDerived, Derived, OnTheRight> - (transpositions.derived(), matrix.derived()); + return Product<MatrixDerived, TranspositionsDerived, DefaultProduct> + (matrix.derived(), transpositions.derived()); } /** \returns the \a matrix with the \a transpositions applied to the rows. */ -template<typename Derived, typename TranspositionDerived> -inline const internal::transposition_matrix_product_retval - <TranspositionDerived, Derived, OnTheLeft> -operator*(const TranspositionsBase<TranspositionDerived> &transpositions, - const MatrixBase<Derived>& matrix) +template<typename TranspositionsDerived, typename MatrixDerived> +EIGEN_DEVICE_FUNC +const Product<TranspositionsDerived, MatrixDerived, DefaultProduct> +operator*(const TranspositionsBase<TranspositionsDerived> &transpositions, + const MatrixBase<MatrixDerived>& matrix) { - return internal::transposition_matrix_product_retval - <TranspositionDerived, Derived, OnTheLeft> - (transpositions.derived(), matrix.derived()); + return Product<TranspositionsDerived, MatrixDerived, DefaultProduct> + (transpositions.derived(), matrix.derived()); } -namespace internal { - -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed> -struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> > -{ - typedef typename MatrixType::PlainObject ReturnType; -}; - -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed> -struct transposition_matrix_product_retval - : public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> > -{ - typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned; - typedef typename TranspositionType::StorageIndex StorageIndex; - - transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix) - : m_transpositions(tr), m_matrix(matrix) - {} - - inline Index rows() const { return m_matrix.rows(); } - inline Index cols() const { return m_matrix.cols(); } - - template<typename Dest> inline void evalTo(Dest& dst) const - { - const Index size = m_transpositions.size(); - StorageIndex j = 0; - - if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix))) - dst = m_matrix; - - for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k) - if(Index(j=m_transpositions.coeff(k))!=k) - { - if(Side==OnTheLeft) - dst.row(k).swap(dst.row(j)); - else if(Side==OnTheRight) - dst.col(k).swap(dst.col(j)); - } - } - - protected: - const TranspositionType& m_transpositions; - typename MatrixType::Nested m_matrix; -}; - -} // end namespace internal /* Template partial specialization for transposed/inverse transpositions */ @@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> > /** \returns the \a matrix with the inverse transpositions applied to the columns. */ - template<typename Derived> friend - inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true> - operator*(const MatrixBase<Derived>& matrix, const Transpose& trt) + template<typename OtherDerived> friend + const Product<OtherDerived, Transpose, DefaultProduct> + operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt) { - return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived()); + return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived()); } /** \returns the \a matrix with the inverse transpositions applied to the rows. */ - template<typename Derived> - inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true> - operator*(const MatrixBase<Derived>& matrix) const + template<typename OtherDerived> + const Product<Transpose, OtherDerived, DefaultProduct> + operator*(const MatrixBase<OtherDerived>& matrix) const { - return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived()); + return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); } protected: const TranspositionType& m_transpositions; }; +namespace internal { + +// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper +// or their transpose; in the future shape should be defined by the expression traits +template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType> +struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template<typename IndicesType> +struct evaluator_traits<TranspositionsWrapper<IndicesType> > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template<typename Derived> +struct evaluator_traits<Transpose<TranspositionsBase<Derived> > > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_TRANSPOSITIONS_H |