aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-10-08 15:57:05 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-10-08 15:57:05 +0200
commitaa6b1aebf373fba262fab7cd833881eac4fed8ef (patch)
treea2c0c211837e1600f69f23da27aa5c5c705ed593 /Eigen/src/Core
parent5cc7251188110f2d24425eac3ce00d051d2b2c55 (diff)
Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/CoreEvaluators.h40
-rw-r--r--Eigen/src/Core/VectorwiseOp.h23
2 files changed, 26 insertions, 37 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index 214114ebe..b96ef99fa 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -965,17 +965,16 @@ protected:
// -------------------- PartialReduxExpr --------------------
-//
-// This is a wrapper around the expression object.
-// TODO: Find out how to write a proper evaluator without duplicating
-// the row() and col() member functions.
template< typename ArgType, typename MemberOp, int Direction>
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
{
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
- typedef typename XprType::Scalar InputScalar;
+ typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
+ typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
+ typedef typename ArgType::Scalar InputScalar;
+ typedef typename XprType::Scalar Scalar;
enum {
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime)
};
@@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits),
- Alignment = 0 // FIXME this could be improved
+ Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
};
- EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr)
- : m_expr(expr)
+ EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
+ : m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
{}
typedef typename XprType::CoeffReturnType CoeffReturnType;
-
- EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
- {
- return m_expr.coeff(row, col);
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
+ {
+ if (Direction==Vertical)
+ return m_functor(m_arg.col(j));
+ else
+ return m_functor(m_arg.row(i));
}
-
- EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
- {
- return m_expr.coeff(index);
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
+ {
+ if (Direction==Vertical)
+ return m_functor(m_arg.col(index));
+ else
+ return m_functor(m_arg.row(index));
}
protected:
- const XprType m_expr;
+ const ArgTypeNested m_arg;
+ const MemberOp m_functor;
};
diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h
index 79c7d135d..5de53732e 100644
--- a/Eigen/src/Core/VectorwiseOp.h
+++ b/Eigen/src/Core/VectorwiseOp.h
@@ -41,8 +41,6 @@ struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> >
typedef typename traits<MatrixType>::StorageKind StorageKind;
typedef typename traits<MatrixType>::XprKind XprKind;
typedef typename MatrixType::Scalar InputScalar;
- typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
- typedef typename remove_all<MatrixTypeNested>::type _MatrixTypeNested;
enum {
RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime,
ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime,
@@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr)
- typedef typename internal::traits<PartialReduxExpr>::MatrixTypeNested MatrixTypeNested;
- typedef typename internal::traits<PartialReduxExpr>::_MatrixTypeNested _MatrixTypeNested;
EIGEN_DEVICE_FUNC
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
@@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
EIGEN_DEVICE_FUNC
Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
- {
- if (Direction==Vertical)
- return m_functor(m_matrix.col(j));
- else
- return m_functor(m_matrix.row(i));
- }
-
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
- {
- if (Direction==Vertical)
- return m_functor(m_matrix.col(index));
- else
- return m_functor(m_matrix.row(index));
- }
+ typename MatrixType::Nested nestedExpression() const { return m_matrix; }
+ const MemberOp& functor() const { return m_functor; }
protected:
- MatrixTypeNested m_matrix;
+ typename MatrixType::Nested m_matrix;
const MemberOp m_functor;
};