diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-10-08 15:57:05 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-10-08 15:57:05 +0200 |
commit | aa6b1aebf373fba262fab7cd833881eac4fed8ef (patch) | |
tree | a2c0c211837e1600f69f23da27aa5c5c705ed593 /Eigen/src/Core | |
parent | 5cc7251188110f2d24425eac3ce00d051d2b2c55 (diff) |
Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 40 | ||||
-rw-r--r-- | Eigen/src/Core/VectorwiseOp.h | 23 |
2 files changed, 26 insertions, 37 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 214114ebe..b96ef99fa 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -965,17 +965,16 @@ protected: // -------------------- PartialReduxExpr -------------------- -// -// This is a wrapper around the expression object. -// TODO: Find out how to write a proper evaluator without duplicating -// the row() and col() member functions. template< typename ArgType, typename MemberOp, int Direction> struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> > : evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> > { typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType; - typedef typename XprType::Scalar InputScalar; + typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested; + typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned; + typedef typename ArgType::Scalar InputScalar; + typedef typename XprType::Scalar Scalar; enum { TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime) }; @@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> > Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits), - Alignment = 0 // FIXME this could be improved + Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized }; - EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr) - : m_expr(expr) + EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr) + : m_arg(xpr.nestedExpression()), m_functor(xpr.functor()) {} typedef typename XprType::CoeffReturnType CoeffReturnType; - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const - { - return m_expr.coeff(row, col); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(j)); + else + return m_functor(m_arg.row(i)); } - - EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const - { - return m_expr.coeff(index); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const + { + if (Direction==Vertical) + return m_functor(m_arg.col(index)); + else + return m_functor(m_arg.row(index)); } protected: - const XprType m_expr; + const ArgTypeNested m_arg; + const MemberOp m_functor; }; diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h index 79c7d135d..5de53732e 100644 --- a/Eigen/src/Core/VectorwiseOp.h +++ b/Eigen/src/Core/VectorwiseOp.h @@ -41,8 +41,6 @@ struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> > typedef typename traits<MatrixType>::StorageKind StorageKind; typedef typename traits<MatrixType>::XprKind XprKind; typedef typename MatrixType::Scalar InputScalar; - typedef typename ref_selector<MatrixType>::type MatrixTypeNested; - typedef typename remove_all<MatrixTypeNested>::type _MatrixTypeNested; enum { RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime, ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime, @@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base; EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr) - typedef typename internal::traits<PartialReduxExpr>::MatrixTypeNested MatrixTypeNested; - typedef typename internal::traits<PartialReduxExpr>::_MatrixTypeNested _MatrixTypeNested; EIGEN_DEVICE_FUNC explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp()) @@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri EIGEN_DEVICE_FUNC Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const - { - if (Direction==Vertical) - return m_functor(m_matrix.col(j)); - else - return m_functor(m_matrix.row(i)); - } - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const - { - if (Direction==Vertical) - return m_functor(m_matrix.col(index)); - else - return m_functor(m_matrix.row(index)); - } + typename MatrixType::Nested nestedExpression() const { return m_matrix; } + const MemberOp& functor() const { return m_functor; } protected: - MatrixTypeNested m_matrix; + typename MatrixType::Nested m_matrix; const MemberOp m_functor; }; |