aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/Transpositions.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-06-19 11:50:24 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-06-19 11:50:24 +0200
commit3af4c6c1c9327411d13386e4719ce48f866c7567 (patch)
tree1cc76e62ae2eaba007b658526534096abe540d43 /Eigen/src/Core/Transpositions.h
parent82b6ac08646f7b12770665134acaf3bb3cdc4dd3 (diff)
Make Transpositions use evaluators
Diffstat (limited to 'Eigen/src/Core/Transpositions.h')
-rw-r--r--Eigen/src/Core/Transpositions.h127
1 files changed, 54 insertions, 73 deletions
diff --git a/Eigen/src/Core/Transpositions.h b/Eigen/src/Core/Transpositions.h
index b08df1ead..dad4f56c9 100644
--- a/Eigen/src/Core/Transpositions.h
+++ b/Eigen/src/Core/Transpositions.h
@@ -41,10 +41,6 @@ namespace Eigen {
* \sa class PermutationMatrix
*/
-namespace internal {
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval;
-}
-
template<typename Derived>
class TranspositionsBase
{
@@ -325,77 +321,32 @@ class TranspositionsWrapper
const typename IndicesType::Nested m_indices;
};
+
+
/** \returns the \a matrix with the \a transpositions applied to the columns.
*/
-template<typename Derived, typename TranspositionsDerived>
-inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight>
-operator*(const MatrixBase<Derived>& matrix,
- const TranspositionsBase<TranspositionsDerived> &transpositions)
+template<typename MatrixDerived, typename TranspositionsDerived>
+EIGEN_DEVICE_FUNC
+const Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
+operator*(const MatrixBase<MatrixDerived> &matrix,
+ const TranspositionsBase<TranspositionsDerived>& transpositions)
{
- return internal::transposition_matrix_product_retval
- <TranspositionsDerived, Derived, OnTheRight>
- (transpositions.derived(), matrix.derived());
+ return Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
+ (matrix.derived(), transpositions.derived());
}
/** \returns the \a matrix with the \a transpositions applied to the rows.
*/
-template<typename Derived, typename TranspositionDerived>
-inline const internal::transposition_matrix_product_retval
- <TranspositionDerived, Derived, OnTheLeft>
-operator*(const TranspositionsBase<TranspositionDerived> &transpositions,
- const MatrixBase<Derived>& matrix)
+template<typename TranspositionsDerived, typename MatrixDerived>
+EIGEN_DEVICE_FUNC
+const Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
+operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
+ const MatrixBase<MatrixDerived>& matrix)
{
- return internal::transposition_matrix_product_retval
- <TranspositionDerived, Derived, OnTheLeft>
- (transpositions.derived(), matrix.derived());
+ return Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
+ (transpositions.derived(), matrix.derived());
}
-namespace internal {
-
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
-struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
-{
- typedef typename MatrixType::PlainObject ReturnType;
-};
-
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
-struct transposition_matrix_product_retval
- : public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
-{
- typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
- typedef typename TranspositionType::StorageIndex StorageIndex;
-
- transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix)
- : m_transpositions(tr), m_matrix(matrix)
- {}
-
- inline Index rows() const { return m_matrix.rows(); }
- inline Index cols() const { return m_matrix.cols(); }
-
- template<typename Dest> inline void evalTo(Dest& dst) const
- {
- const Index size = m_transpositions.size();
- StorageIndex j = 0;
-
- if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)))
- dst = m_matrix;
-
- for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
- if(Index(j=m_transpositions.coeff(k))!=k)
- {
- if(Side==OnTheLeft)
- dst.row(k).swap(dst.row(j));
- else if(Side==OnTheRight)
- dst.col(k).swap(dst.col(j));
- }
- }
-
- protected:
- const TranspositionType& m_transpositions;
- typename MatrixType::Nested m_matrix;
-};
-
-} // end namespace internal
/* Template partial specialization for transposed/inverse transpositions */
@@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
/** \returns the \a matrix with the inverse transpositions applied to the columns.
*/
- template<typename Derived> friend
- inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>
- operator*(const MatrixBase<Derived>& matrix, const Transpose& trt)
+ template<typename OtherDerived> friend
+ const Product<OtherDerived, Transpose, DefaultProduct>
+ operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
{
- return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived());
+ return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived());
}
/** \returns the \a matrix with the inverse transpositions applied to the rows.
*/
- template<typename Derived>
- inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>
- operator*(const MatrixBase<Derived>& matrix) const
+ template<typename OtherDerived>
+ const Product<Transpose, OtherDerived, DefaultProduct>
+ operator*(const MatrixBase<OtherDerived>& matrix) const
{
- return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived());
+ return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
}
protected:
const TranspositionType& m_transpositions;
};
+namespace internal {
+
+// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper
+// or their transpose; in the future shape should be defined by the expression traits
+template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
+struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+template<typename IndicesType>
+struct evaluator_traits<TranspositionsWrapper<IndicesType> >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+template<typename Derived>
+struct evaluator_traits<Transpose<TranspositionsBase<Derived> > >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+} // end namespace internal
+
} // end namespace Eigen
#endif // EIGEN_TRANSPOSITIONS_H