aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/Redux.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2013-12-02 17:54:38 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2013-12-02 17:54:38 +0100
commitf0b82c3ab972a1eafd6aa4f08f4eaffa0d6f1e55 (patch)
treee7117c7773f19786752b4d702a7e7205dc3d1cf2 /Eigen/src/Core/Redux.h
parent6f1a0479b3a8176f4d8db6513cbb39c027c34b7f (diff)
Make reductions compatible with evaluators
Diffstat (limited to 'Eigen/src/Core/Redux.h')
-rw-r--r--Eigen/src/Core/Redux.h81
1 files changed, 74 insertions, 7 deletions
diff --git a/Eigen/src/Core/Redux.h b/Eigen/src/Core/Redux.h
index b2c775d90..1b6a4177a 100644
--- a/Eigen/src/Core/Redux.h
+++ b/Eigen/src/Core/Redux.h
@@ -174,7 +174,7 @@ struct redux_impl<Func, Derived, DefaultTraversal, NoUnrolling>
typedef typename Derived::Scalar Scalar;
typedef typename Derived::Index Index;
EIGEN_DEVICE_FUNC
- static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func)
+ static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
{
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
Scalar res;
@@ -200,10 +200,10 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, NoUnrolling>
typedef typename packet_traits<Scalar>::type PacketScalar;
typedef typename Derived::Index Index;
- static Scalar run(const Derived& mat, const Func& func)
+ static Scalar run(const Derived &mat, const Func& func)
{
const Index size = mat.size();
- eigen_assert(size && "you are using an empty matrix");
+
const Index packetSize = packet_traits<Scalar>::size;
const Index alignedStart = internal::first_aligned(mat);
enum {
@@ -258,7 +258,7 @@ struct redux_impl<Func, Derived, SliceVectorizedTraversal, NoUnrolling>
typedef typename packet_traits<Scalar>::type PacketScalar;
typedef typename Derived::Index Index;
- static Scalar run(const Derived& mat, const Func& func)
+ static Scalar run(const Derived &mat, const Func& func)
{
eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
const Index innerSize = mat.innerSize();
@@ -300,9 +300,8 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
Size = Derived::SizeAtCompileTime,
VectorizedSize = (Size / PacketSize) * PacketSize
};
- static EIGEN_STRONG_INLINE Scalar run(const Derived& mat, const Func& func)
+ static EIGEN_STRONG_INLINE Scalar run(const Derived &mat, const Func& func)
{
- eigen_assert(mat.rows()>0 && mat.cols()>0 && "you are using an empty matrix");
Scalar res = func.predux(redux_vec_unroller<Func, Derived, 0, Size / PacketSize>::run(mat,func));
if (VectorizedSize != Size)
res = func(res,redux_novec_unroller<Func, Derived, VectorizedSize, Size-VectorizedSize>::run(mat,func));
@@ -310,6 +309,64 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
}
};
+#ifdef EIGEN_ENABLE_EVALUATORS
+// evaluator adaptor
+template<typename XprType>
+class redux_evaluator
+{
+public:
+ redux_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) {}
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketScalar PacketScalar;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ enum {
+ MaxRowsAtCompileTime = XprType::MaxRowsAtCompileTime,
+ MaxColsAtCompileTime = XprType::MaxColsAtCompileTime,
+ // TODO we should not remove DirectAccessBit and rather find an elegant way to query the alignment offset at runtime from the evaluator
+ Flags = XprType::Flags & ~DirectAccessBit,
+ IsRowMajor = XprType::IsRowMajor,
+ SizeAtCompileTime = XprType::SizeAtCompileTime,
+ InnerSizeAtCompileTime = XprType::InnerSizeAtCompileTime,
+ CoeffReadCost = XprType::CoeffReadCost
+ };
+
+ Index rows() const { return m_xpr.rows(); }
+ Index cols() const { return m_xpr.cols(); }
+ Index size() const { return m_xpr.size(); }
+ Index innerSize() const { return m_xpr.innerSize(); }
+ Index outerSize() const { return m_xpr.outerSize(); }
+
+ CoeffReturnType coeff(Index row, Index col) const
+ { return m_evaluator.coeff(row, col); }
+
+ CoeffReturnType coeff(Index index) const
+ { return m_evaluator.coeff(index); }
+
+ template<int LoadMode>
+ PacketReturnType packet(Index row, Index col) const
+ { return m_evaluator.template packet<LoadMode>(row, col); }
+
+ template<int LoadMode>
+ PacketReturnType packet(Index index) const
+ { return m_evaluator.template packet<LoadMode>(index); }
+
+ CoeffReturnType coeffByOuterInner(Index outer, Index inner) const
+ { return m_evaluator.coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
+
+ template<int LoadMode>
+ PacketReturnType packetByOuterInner(Index outer, Index inner) const
+ { return m_evaluator.template packet<LoadMode>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer); }
+
+protected:
+ typename internal::evaluator<XprType>::nestedType m_evaluator;
+ const XprType &m_xpr;
+};
+#endif
+
} // end namespace internal
/***************************************************************************
@@ -320,7 +377,7 @@ struct redux_impl<Func, Derived, LinearVectorizedTraversal, CompleteUnrolling>
/** \returns the result of a full redux operation on the whole matrix or vector using \a func
*
* The template parameter \a BinaryOp is the type of the functor \a func which must be
- * an associative operator. Both current STL and TR1 functor styles are handled.
+ * an associative operator. Both current C++98 and C++11 functor styles are handled.
*
* \sa DenseBase::sum(), DenseBase::minCoeff(), DenseBase::maxCoeff(), MatrixBase::colwise(), MatrixBase::rowwise()
*/
@@ -329,9 +386,19 @@ template<typename Func>
EIGEN_STRONG_INLINE typename internal::result_of<Func(typename internal::traits<Derived>::Scalar)>::type
DenseBase<Derived>::redux(const Func& func) const
{
+ eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
+#ifdef EIGEN_TEST_EVALUATORS
+
+ typedef typename internal::redux_evaluator<Derived> ThisEvaluator;
+ ThisEvaluator thisEval(derived());
+ return internal::redux_impl<Func, ThisEvaluator>::run(thisEval, func);
+
+#else
typedef typename internal::remove_all<typename Derived::Nested>::type ThisNested;
+
return internal::redux_impl<Func, ThisNested>
::run(derived(), func);
+#endif
}
/** \returns the minimum of all coefficients of \c *this.