aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-04-04 15:35:14 +0100
committerGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2011-04-04 15:35:14 +0100
commitae06b8af5cd9828240b474363bfabbbab8202e23 (patch)
tree68f4ef8f576f29e368cb0e6d6a7b3cf0633296b7 /Eigen
parentafdd26f2299fd64ca05174d6a25a3847bb3b9c1d (diff)
Make evaluators for Matrix and Array inherit from common base class.
This gets rid of some code duplication.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/CoreEvaluators.h140
1 files changed, 52 insertions, 88 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index 39ef6078c..6fe7177c6 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -106,128 +106,92 @@ protected:
typename evaluator<ExpressionType>::type m_argImpl;
};
-// -------------------- Matrix --------------------
+// -------------------- Matrix and Array--------------------
+//
+// evaluator_impl<PlainObjectBase> is a common base class for the
+// Matrix and Array evaluators.
-template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
-struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
+template<typename Derived>
+struct evaluator_impl<PlainObjectBase<Derived> >
{
- typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType;
+ typedef PlainObjectBase<Derived> PlainObjectType;
- evaluator_impl(const MatrixType& m) : m_matrix(m) {}
+ evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {}
- typedef typename MatrixType::Index Index;
+ typedef typename PlainObjectType::Index Index;
+ typedef typename PlainObjectType::Scalar Scalar;
+ typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
+ typedef typename PlainObjectType::PacketScalar PacketScalar;
+ typedef typename PlainObjectType::PacketReturnType PacketReturnType;
- typename MatrixType::CoeffReturnType coeff(Index i, Index j) const
+ CoeffReturnType coeff(Index i, Index j) const
{
- return m_matrix.coeff(i, j);
+ return m_plainObject.coeff(i, j);
}
- typename MatrixType::CoeffReturnType coeff(Index index) const
+ CoeffReturnType coeff(Index index) const
{
- return m_matrix.coeff(index);
+ return m_plainObject.coeff(index);
}
- typename MatrixType::Scalar& coeffRef(Index i, Index j)
+ Scalar& coeffRef(Index i, Index j)
{
- return m_matrix.const_cast_derived().coeffRef(i, j);
+ return m_plainObject.const_cast_derived().coeffRef(i, j);
}
- typename MatrixType::Scalar& coeffRef(Index index)
+ Scalar& coeffRef(Index index)
{
- return m_matrix.const_cast_derived().coeffRef(index);
+ return m_plainObject.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
- typename MatrixType::PacketReturnType packet(Index row, Index col) const
+ PacketReturnType packet(Index row, Index col) const
{
- return m_matrix.template packet<LoadMode>(row, col);
+ return m_plainObject.template packet<LoadMode>(row, col);
}
template<int LoadMode>
- typename MatrixType::PacketReturnType packet(Index index) const
+ PacketReturnType packet(Index index) const
{
- // eigen_internal_assert(index >= 0 && index < size());
- return m_matrix.template packet<LoadMode>(index);
+ return m_plainObject.template packet<LoadMode>(index);
}
template<int StoreMode>
- void writePacket(Index row, Index col, const typename MatrixType::PacketScalar& x)
+ void writePacket(Index row, Index col, const PacketScalar& x)
{
- m_matrix.const_cast_derived().template writePacket<StoreMode>(row, col, x);
+ m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
- void writePacket(Index index, const typename MatrixType::PacketScalar& x)
+ void writePacket(Index index, const PacketScalar& x)
{
- // eigen_internal_assert(index >= 0 && index < size());
- m_matrix.const_cast_derived().template writePacket<StoreMode>(index, x);
+ m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
- const MatrixType &m_matrix;
+ const PlainObjectType &m_plainObject;
};
-// -------------------- Array --------------------
+template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
+struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
+ : evaluator_impl<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
+{
+ typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType;
-// TODO: should be sharing code with Matrix case
+ evaluator_impl(const MatrixType& m)
+ : evaluator_impl<PlainObjectBase<MatrixType> >(m)
+ { }
+};
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
+ : evaluator_impl<PlainObjectBase<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType;
- evaluator_impl(const ArrayType& a) : m_array(a) {}
-
- typedef typename ArrayType::Index Index;
-
- typename ArrayType::CoeffReturnType coeff(Index i, Index j) const
- {
- return m_array.coeff(i, j);
- }
-
- typename ArrayType::CoeffReturnType coeff(Index index) const
- {
- return m_array.coeff(index);
- }
-
- typename ArrayType::Scalar& coeffRef(Index i, Index j)
- {
- return m_array.const_cast_derived().coeffRef(i, j);
- }
-
- typename ArrayType::Scalar& coeffRef(Index index)
- {
- return m_array.const_cast_derived().coeffRef(index);
- }
-
- template<int LoadMode>
- typename ArrayType::PacketReturnType packet(Index row, Index col) const
- {
- return m_array.template packet<LoadMode>(row, col);
- }
-
- template<int LoadMode>
- typename ArrayType::PacketReturnType packet(Index index) const
- {
- // eigen_internal_assert(index >= 0 && index < size());
- return m_array.template packet<LoadMode>(index);
- }
-
- template<int StoreMode>
- void writePacket(Index row, Index col, const typename ArrayType::PacketScalar& x)
- {
- m_array.const_cast_derived().template writePacket<StoreMode>(row, col, x);
- }
-
- template<int StoreMode>
- void writePacket(Index index, const typename ArrayType::PacketScalar& x)
- {
- // eigen_internal_assert(index >= 0 && index < size());
- m_array.const_cast_derived().template writePacket<StoreMode>(index, x);
- }
-
-protected:
- const ArrayType &m_array;
+ evaluator_impl(const ArrayType& m)
+ : evaluator_impl<PlainObjectBase<ArrayType> >(m)
+ { }
};
// -------------------- CwiseNullaryOp --------------------
@@ -400,8 +364,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
CoeffReturnType coeff(Index index) const
{
- return m_argImpl.coeff(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
- m_startCol + (RowsAtCompileTime == 1 ? index : 0));
+ return coeff(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0);
}
Scalar& coeffRef(Index row, Index col)
@@ -411,8 +375,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
Scalar& coeffRef(Index index)
{
- return m_argImpl.coeffRef(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
- m_startCol + (RowsAtCompileTime == 1 ? index : 0));
+ return coeffRef(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0);
}
template<int LoadMode>
@@ -424,8 +388,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int LoadMode>
PacketReturnType packet(Index index) const
{
- return m_argImpl.template packet<LoadMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
- m_startCol + (RowsAtCompileTime == 1 ? index : 0));
+ return packet<LoadMode>(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0);
}
template<int StoreMode>
@@ -437,9 +401,9 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
- return m_argImpl.template writePacket<StoreMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index),
- m_startCol + (RowsAtCompileTime == 1 ? index : 0),
- x);
+ return writePacket<StoreMode>(RowsAtCompileTime == 1 ? 0 : index,
+ RowsAtCompileTime == 1 ? index : 0,
+ x);
}
protected: