diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 11:50:24 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-06-19 11:50:24 +0200 |
commit | 3af4c6c1c9327411d13386e4719ce48f866c7567 (patch) | |
tree | 1cc76e62ae2eaba007b658526534096abe540d43 /Eigen | |
parent | 82b6ac08646f7b12770665134acaf3bb3cdc4dd3 (diff) |
Make Transpositions use evaluators
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/PermutationMatrix.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 73 | ||||
-rw-r--r-- | Eigen/src/Core/Transpositions.h | 127 | ||||
-rw-r--r-- | Eigen/src/Core/util/Constants.h | 1 |
4 files changed, 128 insertions, 76 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index 9a0c03612..8c9afd4ee 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp }; -// TODO: Do we need to define these operator* functions? Would it be better to have them inherited -// from MatrixBase? - /** \returns the matrix with the permutation applied to the columns. */ template<typename MatrixDerived, typename PermutationDerived> diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 9d1cb5d56..4c673a6cb 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -929,6 +929,79 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, } }; + +/*************************************************************************** +* Products with transpositions matrices +***************************************************************************/ + +// FIXME could we unify Transpositions and Permutation into a single "shape"?? + +/** \internal + * \class transposition_matrix_product + * Internal helper class implementing the product between a permutation matrix and a matrix. + */ +template<typename MatrixType, int Side, bool Transposed, typename MatrixShape> +struct transposition_matrix_product +{ + template<typename Dest, typename TranspositionType> + static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat) + { + typedef typename TranspositionType::StorageIndex StorageIndex; + const Index size = tr.size(); + StorageIndex j = 0; + + if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat))) + dst = mat; + + for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k) + if(Index(j=tr.coeff(k))!=k) + { + if(Side==OnTheLeft) dst.row(k).swap(dst.row(j)); + else if(Side==OnTheRight) dst.col(k).swap(dst.col(j)); + } + } +}; + +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) + { + transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> +struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) + { + transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs); + } +}; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/Transpositions.h b/Eigen/src/Core/Transpositions.h index b08df1ead..dad4f56c9 100644 --- a/Eigen/src/Core/Transpositions.h +++ b/Eigen/src/Core/Transpositions.h @@ -41,10 +41,6 @@ namespace Eigen { * \sa class PermutationMatrix */ -namespace internal { -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval; -} - template<typename Derived> class TranspositionsBase { @@ -325,77 +321,32 @@ class TranspositionsWrapper const typename IndicesType::Nested m_indices; }; + + /** \returns the \a matrix with the \a transpositions applied to the columns. */ -template<typename Derived, typename TranspositionsDerived> -inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight> -operator*(const MatrixBase<Derived>& matrix, - const TranspositionsBase<TranspositionsDerived> &transpositions) +template<typename MatrixDerived, typename TranspositionsDerived> +EIGEN_DEVICE_FUNC +const Product<MatrixDerived, TranspositionsDerived, DefaultProduct> +operator*(const MatrixBase<MatrixDerived> &matrix, + const TranspositionsBase<TranspositionsDerived>& transpositions) { - return internal::transposition_matrix_product_retval - <TranspositionsDerived, Derived, OnTheRight> - (transpositions.derived(), matrix.derived()); + return Product<MatrixDerived, TranspositionsDerived, DefaultProduct> + (matrix.derived(), transpositions.derived()); } /** \returns the \a matrix with the \a transpositions applied to the rows. */ -template<typename Derived, typename TranspositionDerived> -inline const internal::transposition_matrix_product_retval - <TranspositionDerived, Derived, OnTheLeft> -operator*(const TranspositionsBase<TranspositionDerived> &transpositions, - const MatrixBase<Derived>& matrix) +template<typename TranspositionsDerived, typename MatrixDerived> +EIGEN_DEVICE_FUNC +const Product<TranspositionsDerived, MatrixDerived, DefaultProduct> +operator*(const TranspositionsBase<TranspositionsDerived> &transpositions, + const MatrixBase<MatrixDerived>& matrix) { - return internal::transposition_matrix_product_retval - <TranspositionDerived, Derived, OnTheLeft> - (transpositions.derived(), matrix.derived()); + return Product<TranspositionsDerived, MatrixDerived, DefaultProduct> + (transpositions.derived(), matrix.derived()); } -namespace internal { - -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed> -struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> > -{ - typedef typename MatrixType::PlainObject ReturnType; -}; - -template<typename TranspositionType, typename MatrixType, int Side, bool Transposed> -struct transposition_matrix_product_retval - : public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> > -{ - typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned; - typedef typename TranspositionType::StorageIndex StorageIndex; - - transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix) - : m_transpositions(tr), m_matrix(matrix) - {} - - inline Index rows() const { return m_matrix.rows(); } - inline Index cols() const { return m_matrix.cols(); } - - template<typename Dest> inline void evalTo(Dest& dst) const - { - const Index size = m_transpositions.size(); - StorageIndex j = 0; - - if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix))) - dst = m_matrix; - - for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k) - if(Index(j=m_transpositions.coeff(k))!=k) - { - if(Side==OnTheLeft) - dst.row(k).swap(dst.row(j)); - else if(Side==OnTheRight) - dst.col(k).swap(dst.col(j)); - } - } - - protected: - const TranspositionType& m_transpositions; - typename MatrixType::Nested m_matrix; -}; - -} // end namespace internal /* Template partial specialization for transposed/inverse transpositions */ @@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> > /** \returns the \a matrix with the inverse transpositions applied to the columns. */ - template<typename Derived> friend - inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true> - operator*(const MatrixBase<Derived>& matrix, const Transpose& trt) + template<typename OtherDerived> friend + const Product<OtherDerived, Transpose, DefaultProduct> + operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt) { - return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived()); + return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived()); } /** \returns the \a matrix with the inverse transpositions applied to the rows. */ - template<typename Derived> - inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true> - operator*(const MatrixBase<Derived>& matrix) const + template<typename OtherDerived> + const Product<Transpose, OtherDerived, DefaultProduct> + operator*(const MatrixBase<OtherDerived>& matrix) const { - return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived()); + return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); } protected: const TranspositionType& m_transpositions; }; +namespace internal { + +// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper +// or their transpose; in the future shape should be defined by the expression traits +template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType> +struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template<typename IndicesType> +struct evaluator_traits<TranspositionsWrapper<IndicesType> > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template<typename Derived> +struct evaluator_traits<Transpose<TranspositionsBase<Derived> > > +{ + typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_TRANSPOSITIONS_H diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 419409608..3e6c75444 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha struct TriangularShape { static std::string debugName() { return "TriangularShape"; } }; struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } }; struct PermutationShape { static std::string debugName() { return "PermutationShape"; } }; +struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } }; struct SparseShape { static std::string debugName() { return "SparseShape"; } }; namespace internal { |