aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-06-19 11:50:24 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-06-19 11:50:24 +0200
commit3af4c6c1c9327411d13386e4719ce48f866c7567 (patch)
tree1cc76e62ae2eaba007b658526534096abe540d43 /Eigen/src/Core
parent82b6ac08646f7b12770665134acaf3bb3cdc4dd3 (diff)
Make Transpositions use evaluators
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/PermutationMatrix.h3
-rw-r--r--Eigen/src/Core/ProductEvaluators.h73
-rw-r--r--Eigen/src/Core/Transpositions.h127
-rw-r--r--Eigen/src/Core/util/Constants.h1
4 files changed, 128 insertions, 76 deletions
diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h
index 9a0c03612..8c9afd4ee 100644
--- a/Eigen/src/Core/PermutationMatrix.h
+++ b/Eigen/src/Core/PermutationMatrix.h
@@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
};
-// TODO: Do we need to define these operator* functions? Would it be better to have them inherited
-// from MatrixBase?
-
/** \returns the matrix with the permutation applied to the columns.
*/
template<typename MatrixDerived, typename PermutationDerived>
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 9d1cb5d56..4c673a6cb 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -929,6 +929,79 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
}
};
+
+/***************************************************************************
+* Products with transpositions matrices
+***************************************************************************/
+
+// FIXME could we unify Transpositions and Permutation into a single "shape"??
+
+/** \internal
+ * \class transposition_matrix_product
+ * Internal helper class implementing the product between a permutation matrix and a matrix.
+ */
+template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
+struct transposition_matrix_product
+{
+ template<typename Dest, typename TranspositionType>
+ static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
+ {
+ typedef typename TranspositionType::StorageIndex StorageIndex;
+ const Index size = tr.size();
+ StorageIndex j = 0;
+
+ if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
+ dst = mat;
+
+ for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
+ if(Index(j=tr.coeff(k))!=k)
+ {
+ if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
+ else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
+ }
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
+struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
+struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
+struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
+ {
+ transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
+struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
+ {
+ transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
+ }
+};
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/Transpositions.h b/Eigen/src/Core/Transpositions.h
index b08df1ead..dad4f56c9 100644
--- a/Eigen/src/Core/Transpositions.h
+++ b/Eigen/src/Core/Transpositions.h
@@ -41,10 +41,6 @@ namespace Eigen {
* \sa class PermutationMatrix
*/
-namespace internal {
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval;
-}
-
template<typename Derived>
class TranspositionsBase
{
@@ -325,77 +321,32 @@ class TranspositionsWrapper
const typename IndicesType::Nested m_indices;
};
+
+
/** \returns the \a matrix with the \a transpositions applied to the columns.
*/
-template<typename Derived, typename TranspositionsDerived>
-inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight>
-operator*(const MatrixBase<Derived>& matrix,
- const TranspositionsBase<TranspositionsDerived> &transpositions)
+template<typename MatrixDerived, typename TranspositionsDerived>
+EIGEN_DEVICE_FUNC
+const Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
+operator*(const MatrixBase<MatrixDerived> &matrix,
+ const TranspositionsBase<TranspositionsDerived>& transpositions)
{
- return internal::transposition_matrix_product_retval
- <TranspositionsDerived, Derived, OnTheRight>
- (transpositions.derived(), matrix.derived());
+ return Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
+ (matrix.derived(), transpositions.derived());
}
/** \returns the \a matrix with the \a transpositions applied to the rows.
*/
-template<typename Derived, typename TranspositionDerived>
-inline const internal::transposition_matrix_product_retval
- <TranspositionDerived, Derived, OnTheLeft>
-operator*(const TranspositionsBase<TranspositionDerived> &transpositions,
- const MatrixBase<Derived>& matrix)
+template<typename TranspositionsDerived, typename MatrixDerived>
+EIGEN_DEVICE_FUNC
+const Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
+operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
+ const MatrixBase<MatrixDerived>& matrix)
{
- return internal::transposition_matrix_product_retval
- <TranspositionDerived, Derived, OnTheLeft>
- (transpositions.derived(), matrix.derived());
+ return Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
+ (transpositions.derived(), matrix.derived());
}
-namespace internal {
-
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
-struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
-{
- typedef typename MatrixType::PlainObject ReturnType;
-};
-
-template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
-struct transposition_matrix_product_retval
- : public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
-{
- typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
- typedef typename TranspositionType::StorageIndex StorageIndex;
-
- transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix)
- : m_transpositions(tr), m_matrix(matrix)
- {}
-
- inline Index rows() const { return m_matrix.rows(); }
- inline Index cols() const { return m_matrix.cols(); }
-
- template<typename Dest> inline void evalTo(Dest& dst) const
- {
- const Index size = m_transpositions.size();
- StorageIndex j = 0;
-
- if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)))
- dst = m_matrix;
-
- for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
- if(Index(j=m_transpositions.coeff(k))!=k)
- {
- if(Side==OnTheLeft)
- dst.row(k).swap(dst.row(j));
- else if(Side==OnTheRight)
- dst.col(k).swap(dst.col(j));
- }
- }
-
- protected:
- const TranspositionType& m_transpositions;
- typename MatrixType::Nested m_matrix;
-};
-
-} // end namespace internal
/* Template partial specialization for transposed/inverse transpositions */
@@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
/** \returns the \a matrix with the inverse transpositions applied to the columns.
*/
- template<typename Derived> friend
- inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>
- operator*(const MatrixBase<Derived>& matrix, const Transpose& trt)
+ template<typename OtherDerived> friend
+ const Product<OtherDerived, Transpose, DefaultProduct>
+ operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
{
- return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived());
+ return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived());
}
/** \returns the \a matrix with the inverse transpositions applied to the rows.
*/
- template<typename Derived>
- inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>
- operator*(const MatrixBase<Derived>& matrix) const
+ template<typename OtherDerived>
+ const Product<Transpose, OtherDerived, DefaultProduct>
+ operator*(const MatrixBase<OtherDerived>& matrix) const
{
- return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived());
+ return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
}
protected:
const TranspositionType& m_transpositions;
};
+namespace internal {
+
+// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper
+// or their transpose; in the future shape should be defined by the expression traits
+template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
+struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+template<typename IndicesType>
+struct evaluator_traits<TranspositionsWrapper<IndicesType> >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+template<typename Derived>
+struct evaluator_traits<Transpose<TranspositionsBase<Derived> > >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef TranspositionsShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+} // end namespace internal
+
} // end namespace Eigen
#endif // EIGEN_TRANSPOSITIONS_H
diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h
index 419409608..3e6c75444 100644
--- a/Eigen/src/Core/util/Constants.h
+++ b/Eigen/src/Core/util/Constants.h
@@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
+struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } };
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
namespace internal {