diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-22 22:36:45 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-22 22:36:45 +0100 |
commit | bb2d70d211a8fc8184b690b75d29ba484edace0e (patch) | |
tree | 2325d715f307dc558f32e560951939d91aa37f11 | |
parent | 6441e8727b32c6cbb194c0ce1bbd784c2a24a2b2 (diff) |
Implement evaluators for ArrayWrapper and MatrixWrapper.
-rw-r--r-- | Eigen/src/Core/ArrayWrapper.h | 12 | ||||
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 85 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 1 | ||||
-rw-r--r-- | test/evaluators.cpp | 12 |
4 files changed, 109 insertions, 1 deletions
diff --git a/Eigen/src/Core/ArrayWrapper.h b/Eigen/src/Core/ArrayWrapper.h index 7ba01de36..6c7e2b198 100644 --- a/Eigen/src/Core/ArrayWrapper.h +++ b/Eigen/src/Core/ArrayWrapper.h @@ -119,6 +119,12 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> > template<typename Dest> inline void evalTo(Dest& dst) const { dst = m_expression; } + const typename internal::remove_all<NestedExpressionType>::type& + nestedExpression() const + { + return m_expression; + } + protected: const NestedExpressionType m_expression; }; @@ -214,6 +220,12 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> > m_expression.const_cast_derived().template writePacket<LoadMode>(index, x); } + const typename internal::remove_all<NestedExpressionType>::type& + nestedExpression() const + { + return m_expression; + } + protected: const NestedExpressionType m_expression; }; diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 6b08c78a0..47835f576 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -106,7 +106,7 @@ protected: typename evaluator<ExpressionType>::type m_argImpl; }; -// -------------------- Matrix and Array-------------------- +// -------------------- Matrix and Array -------------------- // // evaluator_impl<PlainObjectBase> is a common base class for the // Matrix and Array evaluators. @@ -704,6 +704,89 @@ protected: }; +// -------------------- MatrixWrapper and ArrayWrapper -------------------- +// +// evaluator_impl_wrapper_base<T> is a common base class for the +// MatrixWrapper and ArrayWrapper evaluators. + +template<typename ArgType> +struct evaluator_impl_wrapper_base +{ + evaluator_impl_wrapper_base(const ArgType& arg) : m_argImpl(arg) {} + + typedef typename ArgType::Index Index; + typedef typename ArgType::Scalar Scalar; + typedef typename ArgType::CoeffReturnType CoeffReturnType; + typedef typename ArgType::PacketScalar PacketScalar; + typedef typename ArgType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(row, col); + } + + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template<int LoadMode> + PacketReturnType packet(Index row, Index col) const + { + return m_argImpl.template packet<LoadMode>(row, col); + } + + template<int LoadMode> + PacketReturnType packet(Index index) const + { + return m_argImpl.template packet<LoadMode>(index); + } + + template<int StoreMode> + void writePacket(Index row, Index col, const PacketScalar& x) + { + m_argImpl.template writePacket<StoreMode>(row, col, x); + } + + template<int StoreMode> + void writePacket(Index index, const PacketScalar& x) + { + m_argImpl.template writePacket<StoreMode>(index, x); + } + +protected: + typename evaluator<ArgType>::type m_argImpl; +}; + +template<typename ArgType> +struct evaluator_impl<MatrixWrapper<ArgType> > + : evaluator_impl_wrapper_base<ArgType> +{ + evaluator_impl(const MatrixWrapper<ArgType>& wrapper) + : evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression()) + { } +}; + +template<typename ArgType> +struct evaluator_impl<ArrayWrapper<ArgType> > + : evaluator_impl_wrapper_base<ArgType> +{ + evaluator_impl(const ArrayWrapper<ArgType>& wrapper) + : evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression()) + { } +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 7fbccf98c..ce784daed 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -133,6 +133,7 @@ template<typename ExpressionType> class WithFormat; template<typename MatrixType> struct CommaInitializer; template<typename Derived> class ReturnByValue; template<typename ExpressionType> class ArrayWrapper; +template<typename ExpressionType> class MatrixWrapper; namespace internal { template<typename DecompositionType, typename Rhs> struct solve_retval_base; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 4c55736eb..da6b9064b 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -180,4 +180,16 @@ void test_evaluators() VectorXd vec1(6); VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.rowwise().sum()); VERIFY_IS_APPROX_EVALUATOR(vec1, mat1.colwise().sum().transpose()); + + // test MatrixWrapper and ArrayWrapper + mat1.setRandom(6,6); + arr1.setRandom(6,6); + VERIFY_IS_APPROX_EVALUATOR(mat2, arr1.matrix()); + VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array()); + VERIFY_IS_APPROX_EVALUATOR(mat2, (arr1 + 2).matrix()); + VERIFY_IS_APPROX_EVALUATOR(arr2, mat1.array() + 2); + mat2.array() = arr1 * arr1; + VERIFY_IS_APPROX(mat2, (arr1 * arr1).matrix()); + arr2.matrix() = MatrixXd::Identity(6,6); + VERIFY_IS_APPROX(arr2, MatrixXd::Identity(6,6).array()); } |