diff options
author | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-04 15:35:14 +0100 |
---|---|---|
committer | Jitse Niesen <jitse@maths.leeds.ac.uk> | 2011-04-04 15:35:14 +0100 |
commit | ae06b8af5cd9828240b474363bfabbbab8202e23 (patch) | |
tree | 68f4ef8f576f29e368cb0e6d6a7b3cf0633296b7 /Eigen | |
parent | afdd26f2299fd64ca05174d6a25a3847bb3b9c1d (diff) |
Make evaluators for Matrix and Array inherit from common base class.
This gets rid of some code duplication.
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 140 |
1 files changed, 52 insertions, 88 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 39ef6078c..6fe7177c6 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -106,128 +106,92 @@ protected: typename evaluator<ExpressionType>::type m_argImpl; }; -// -------------------- Matrix -------------------- +// -------------------- Matrix and Array-------------------- +// +// evaluator_impl<PlainObjectBase> is a common base class for the +// Matrix and Array evaluators. -template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> -struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > +template<typename Derived> +struct evaluator_impl<PlainObjectBase<Derived> > { - typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType; + typedef PlainObjectBase<Derived> PlainObjectType; - evaluator_impl(const MatrixType& m) : m_matrix(m) {} + evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {} - typedef typename MatrixType::Index Index; + typedef typename PlainObjectType::Index Index; + typedef typename PlainObjectType::Scalar Scalar; + typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; + typedef typename PlainObjectType::PacketScalar PacketScalar; + typedef typename PlainObjectType::PacketReturnType PacketReturnType; - typename MatrixType::CoeffReturnType coeff(Index i, Index j) const + CoeffReturnType coeff(Index i, Index j) const { - return m_matrix.coeff(i, j); + return m_plainObject.coeff(i, j); } - typename MatrixType::CoeffReturnType coeff(Index index) const + CoeffReturnType coeff(Index index) const { - return m_matrix.coeff(index); + return m_plainObject.coeff(index); } - typename MatrixType::Scalar& coeffRef(Index i, Index j) + Scalar& coeffRef(Index i, Index j) { - return m_matrix.const_cast_derived().coeffRef(i, j); + return m_plainObject.const_cast_derived().coeffRef(i, j); } - typename MatrixType::Scalar& coeffRef(Index index) + Scalar& coeffRef(Index index) { - return m_matrix.const_cast_derived().coeffRef(index); + return m_plainObject.const_cast_derived().coeffRef(index); } template<int LoadMode> - typename MatrixType::PacketReturnType packet(Index row, Index col) const + PacketReturnType packet(Index row, Index col) const { - return m_matrix.template packet<LoadMode>(row, col); + return m_plainObject.template packet<LoadMode>(row, col); } template<int LoadMode> - typename MatrixType::PacketReturnType packet(Index index) const + PacketReturnType packet(Index index) const { - // eigen_internal_assert(index >= 0 && index < size()); - return m_matrix.template packet<LoadMode>(index); + return m_plainObject.template packet<LoadMode>(index); } template<int StoreMode> - void writePacket(Index row, Index col, const typename MatrixType::PacketScalar& x) + void writePacket(Index row, Index col, const PacketScalar& x) { - m_matrix.const_cast_derived().template writePacket<StoreMode>(row, col, x); + m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x); } template<int StoreMode> - void writePacket(Index index, const typename MatrixType::PacketScalar& x) + void writePacket(Index index, const PacketScalar& x) { - // eigen_internal_assert(index >= 0 && index < size()); - m_matrix.const_cast_derived().template writePacket<StoreMode>(index, x); + m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x); } protected: - const MatrixType &m_matrix; + const PlainObjectType &m_plainObject; }; -// -------------------- Array -------------------- +template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> +struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > + : evaluator_impl<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > > +{ + typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType; -// TODO: should be sharing code with Matrix case + evaluator_impl(const MatrixType& m) + : evaluator_impl<PlainObjectBase<MatrixType> >(m) + { } +}; template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > + : evaluator_impl<PlainObjectBase<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > > { typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType; - evaluator_impl(const ArrayType& a) : m_array(a) {} - - typedef typename ArrayType::Index Index; - - typename ArrayType::CoeffReturnType coeff(Index i, Index j) const - { - return m_array.coeff(i, j); - } - - typename ArrayType::CoeffReturnType coeff(Index index) const - { - return m_array.coeff(index); - } - - typename ArrayType::Scalar& coeffRef(Index i, Index j) - { - return m_array.const_cast_derived().coeffRef(i, j); - } - - typename ArrayType::Scalar& coeffRef(Index index) - { - return m_array.const_cast_derived().coeffRef(index); - } - - template<int LoadMode> - typename ArrayType::PacketReturnType packet(Index row, Index col) const - { - return m_array.template packet<LoadMode>(row, col); - } - - template<int LoadMode> - typename ArrayType::PacketReturnType packet(Index index) const - { - // eigen_internal_assert(index >= 0 && index < size()); - return m_array.template packet<LoadMode>(index); - } - - template<int StoreMode> - void writePacket(Index row, Index col, const typename ArrayType::PacketScalar& x) - { - m_array.const_cast_derived().template writePacket<StoreMode>(row, col, x); - } - - template<int StoreMode> - void writePacket(Index index, const typename ArrayType::PacketScalar& x) - { - // eigen_internal_assert(index >= 0 && index < size()); - m_array.const_cast_derived().template writePacket<StoreMode>(index, x); - } - -protected: - const ArrayType &m_array; + evaluator_impl(const ArrayType& m) + : evaluator_impl<PlainObjectBase<ArrayType> >(m) + { } }; // -------------------- CwiseNullaryOp -------------------- @@ -400,8 +364,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir CoeffReturnType coeff(Index index) const { - return m_argImpl.coeff(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), - m_startCol + (RowsAtCompileTime == 1 ? index : 0)); + return coeff(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0); } Scalar& coeffRef(Index row, Index col) @@ -411,8 +375,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir Scalar& coeffRef(Index index) { - return m_argImpl.coeffRef(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), - m_startCol + (RowsAtCompileTime == 1 ? index : 0)); + return coeffRef(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0); } template<int LoadMode> @@ -424,8 +388,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir template<int LoadMode> PacketReturnType packet(Index index) const { - return m_argImpl.template packet<LoadMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), - m_startCol + (RowsAtCompileTime == 1 ? index : 0)); + return packet<LoadMode>(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0); } template<int StoreMode> @@ -437,9 +401,9 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir template<int StoreMode> void writePacket(Index index, const PacketScalar& x) { - return m_argImpl.template writePacket<StoreMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), - m_startCol + (RowsAtCompileTime == 1 ? index : 0), - x); + return writePacket<StoreMode>(RowsAtCompileTime == 1 ? 0 : index, + RowsAtCompileTime == 1 ? index : 0, + x); } protected: |