diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-10-08 15:57:05 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-10-08 15:57:05 +0200 |
commit | aa6b1aebf373fba262fab7cd833881eac4fed8ef (patch) | |
tree | a2c0c211837e1600f69f23da27aa5c5c705ed593 /Eigen/src/Core/CoreEvaluators.h | |
parent | 5cc7251188110f2d24425eac3ce00d051d2b2c55 (diff) |
Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression
Diffstat (limited to 'Eigen/src/Core/CoreEvaluators.h')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 214114ebe..b96ef99fa 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -965,17 +965,16 @@ protected: // -------------------- PartialReduxExpr -------------------- -// -// This is a wrapper around the expression object. -// TODO: Find out how to write a proper evaluator without duplicating -// the row() and col() member functions. template< typename ArgType, typename MemberOp, int Direction> struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> > : evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> > { typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType; - typedef typename XprType::Scalar InputScalar; + typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested; + typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned; + typedef typename ArgType::Scalar InputScalar; + typedef typename XprType::Scalar Scalar; enum { TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime) }; @@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> > Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits), - Alignment = 0 // FIXME this could be improved + Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized }; - EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr) - : m_expr(expr) + EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr) + : m_arg(xpr.nestedExpression()), m_functor(xpr.functor()) {} typedef typename XprType::CoeffReturnType CoeffReturnType; - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const - { - return m_expr.coeff(row, col); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(j)); + else + return m_functor(m_arg.row(i)); } - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const - { - return m_expr.coeff(index); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(index)); + else + return m_functor(m_arg.row(index)); } protected: - const XprType m_expr; + const ArgTypeNested m_arg; + const MemberOp m_functor; }; |