diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-12 22:54:31 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-12 22:54:31 +0100 |
commit | 11164830f54d0a8db32f1274f9c971115b5b2eee (patch) | |
tree | 36b3b40f0fbf6c715204d028af793b6b392b0f4f | |
parent | 12a30a982feab745d36d647ab88dfb0a51da2213 (diff) |
Implement evaluator for Replicate.
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 50 | ||||
-rw-r--r-- | Eigen/src/Core/Replicate.h | 4 | ||||
-rw-r--r-- | test/evaluators.cpp | 9 |
3 files changed, 63 insertions, 0 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 756ebde1e..ef6913add 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -615,6 +615,56 @@ protected: }; +// -------------------- Replicate -------------------- + +template<typename XprType, int RowFactor, int ColFactor> +struct evaluator_impl<Replicate<XprType, RowFactor, ColFactor> > +{ + typedef Replicate<XprType, RowFactor, ColFactor> ReplicateType; + + evaluator_impl(const ReplicateType& replicate) + : m_argImpl(replicate.nestedExpression()), + m_rows(replicate.nestedExpression().rows()), + m_cols(replicate.nestedExpression().cols()) + { } + + typedef typename ReplicateType::Index Index; + typedef typename ReplicateType::CoeffReturnType CoeffReturnType; + typedef typename ReplicateType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + // try to avoid using modulo; this is a pure optimization strategy + const Index actual_row = internal::traits<XprType>::RowsAtCompileTime==1 ? 0 + : RowFactor==1 ? row + : row % m_rows; + const Index actual_col = internal::traits<XprType>::ColsAtCompileTime==1 ? 0 + : ColFactor==1 ? col + : col % m_cols; + + return m_argImpl.coeff(actual_row, actual_col); + } + + template<int LoadMode> + PacketReturnType packet(Index row, Index col) const + { + const Index actual_row = internal::traits<XprType>::RowsAtCompileTime==1 ? 0 + : RowFactor==1 ? row + : row % m_rows; + const Index actual_col = internal::traits<XprType>::ColsAtCompileTime==1 ? 0 + : ColFactor==1 ? col + : col % m_cols; + + return m_argImpl.template packet<LoadMode>(actual_row, actual_col); + } + +protected: + typename evaluator<XprType>::type m_argImpl; + Index m_rows; + Index m_cols; +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/Replicate.h b/Eigen/src/Core/Replicate.h index d2f9712db..2acb8ab77 100644 --- a/Eigen/src/Core/Replicate.h +++ b/Eigen/src/Core/Replicate.h @@ -122,6 +122,10 @@ template<typename MatrixType,int RowFactor,int ColFactor> class Replicate return m_matrix.template packet<LoadMode>(actual_row, actual_col); } + const typename internal::remove_all<typename MatrixType::Nested>::type& nestedExpression() const + { + return m_matrix; + } protected: const typename MatrixType::Nested m_matrix; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index fc6fda557..8e0dadb06 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -167,4 +167,13 @@ void test_evaluators() // test Select VERIFY_IS_APPROX_EVALUATOR(aX, (aXsrc > 0).select(aXsrc, -aXsrc)); + + // test Replicate + mXsrc = MatrixXf::Random(6, 6); + VectorXf vX = VectorXf::Random(6); + mX.resize(6, 6); + VERIFY_IS_APPROX_EVALUATOR(mX, mXsrc.colwise() + vX); + matXcd.resize(12, 12); + VERIFY_IS_APPROX_EVALUATOR(matXcd, matXcd_ref.replicate(2,2)); + VERIFY_IS_APPROX_EVALUATOR(matXcd, (matXcd_ref.replicate<2,2>())); } |