diff options
-rw-r--r-- | src/MatrixBase.h | 98 | ||||
-rw-r--r-- | src/MatrixXpr.h | 5 | ||||
-rw-r--r-- | src/ScalarOps.h | 24 | ||||
-rw-r--r-- | src/Util.h | 9 | ||||
-rw-r--r-- | test/matrixmanip.cpp | 1 | ||||
-rw-r--r-- | test/matrixops.cpp | 2 | ||||
-rw-r--r-- | test/vectorops.cpp | 2 |
7 files changed, 117 insertions, 24 deletions
diff --git a/src/MatrixBase.h b/src/MatrixBase.h index bcc747aca..860b9272d 100644 --- a/src/MatrixBase.h +++ b/src/MatrixBase.h @@ -67,6 +67,7 @@ template<typename MatrixType> class MatrixRef { public: typedef typename ForwardDecl<MatrixType>::Scalar Scalar; + typedef MatrixXpr<MatrixRef<MatrixType> > Xpr; MatrixRef(MatrixType& matrix) : m_matrix(matrix) {} MatrixRef(const MatrixRef& other) : m_matrix(other.m_matrix) {} @@ -92,6 +93,11 @@ template<typename MatrixType> class MatrixRef MatrixType& matrix() { return m_matrix; } + Xpr xpr() + { + return Xpr(*this); + } + protected: MatrixType& m_matrix; }; @@ -106,6 +112,7 @@ class MatrixBase typedef MatrixRef<MatrixBase<Derived> > Ref; typedef MatrixConstXpr<ConstRef> ConstXpr; typedef MatrixXpr<Ref> Xpr; + typedef MatrixAlias<Derived> Alias; Ref ref() { @@ -127,6 +134,8 @@ class MatrixBase return ConstXpr(constRef()); } + Alias alias(); + static bool hasDynamicNumRows() { return Derived::_hasDynamicNumRows(); @@ -175,29 +184,16 @@ class MatrixBase } template<typename XprContent> - void operator=(const MatrixConstXpr<XprContent> &xpr) - { - resize(xpr.rows(), xpr.cols()); - for(int i = 0; i < rows(); i++) - for(int j = 0; j < cols(); j++) - this->operator()(i, j) = xpr(i, j); - } - - void operator=(const MatrixBase &other) + MatrixBase& operator=(const MatrixConstXpr<XprContent> &otherXpr) { - resize(other.rows(), other.cols()); - for(int i = 0; i < rows(); i++) - for(int j = 0; j < cols(); j++) - this->operator()(i, j) = other(i, j); + resize(otherXpr.rows(), otherXpr.cols()); + xpr() = otherXpr; + return *this; } - template<typename XprContent> - void operator<<(const MatrixConstXpr<XprContent> &xpr) + MatrixBase& operator=(const MatrixBase &other) { - Derived tmp(xpr.rows(), xpr.cols()); - MatrixBase *ptr = static_cast<MatrixBase*>(&tmp); - *ptr = xpr; - *this = *ptr; + return *this = other.constXpr(); } MatrixConstXpr<MatrixRow<const ConstRef> > row(int i) const; @@ -213,12 +209,13 @@ class MatrixBase template<typename Content> template<typename Derived> -void MatrixXpr<Content>::operator=(const MatrixBase<Derived>& matrix) +MatrixXpr<Content>& MatrixXpr<Content>::operator=(const MatrixBase<Derived>& matrix) { assert(rows() == matrix.rows() && cols() == matrix.cols()); for(int i = 0; i < rows(); i++) for(int j = 0; j < cols(); j++) this->operator()(i, j) = matrix(i, j); + return *this; } template<typename Derived> @@ -252,6 +249,67 @@ std::ostream & operator << (std::ostream & s, return s; } +template<typename Derived> class MatrixAlias +{ + public: + typedef typename Derived::Scalar Scalar; + typedef MatrixRef<MatrixAlias<Derived> > Ref; + typedef MatrixXpr<Ref> Xpr; + + MatrixAlias(Derived& matrix) : m_ref(matrix), m_tmp(matrix) {} + MatrixAlias(const MatrixAlias& other) : m_ref(other.m_ref), m_tmp(other.m_tmp) {} + + ~MatrixAlias() + { + m_ref.xpr() = m_tmp; + } + + Xpr xpr() + { + return Xpr(ref()); + } + + static bool hasDynamicNumRows() + { + return MatrixBase<Derived>::hasDynamicNumRows(); + } + + static bool hasDynamicNumCols() + { + return MatrixBase<Derived>::hasDynamicNumCols(); + } + + int rows() const { return m_tmp.rows(); } + int cols() const { return m_tmp.cols(); } + + Scalar& operator()(int row, int col) + { + return m_tmp(row, col); + } + + Ref ref() + { + return Ref(*this); + } + + template<typename XprContent> + void operator=(const MatrixConstXpr<XprContent> &xpr) + { + ref().xpr() = xpr; + } + + protected: + MatrixRef<MatrixBase<Derived> > m_ref; + Derived m_tmp; +}; + +template<typename Derived> +typename MatrixBase<Derived>::Alias +MatrixBase<Derived>::alias() +{ + return Alias(*static_cast<Derived*>(this)); +} + } // namespace Eigen #endif // EIGEN_MATRIXBASE_H diff --git a/src/MatrixXpr.h b/src/MatrixXpr.h index 9a1ab1085..937d4e280 100644 --- a/src/MatrixXpr.h +++ b/src/MatrixXpr.h @@ -97,16 +97,17 @@ template<typename Content> class MatrixXpr } template<typename OtherContent> - void operator=(const MatrixConstXpr<OtherContent> &other) + MatrixXpr& operator=(const MatrixConstXpr<OtherContent> &other) { assert(rows() == other.rows() && cols() == other.cols()); for(int i = 0; i < rows(); i++) for(int j = 0; j < cols(); j++) this->operator()(i, j) = other(i, j); + return *this; } template<typename Derived> - void operator=(const MatrixBase<Derived>& matrix); + MatrixXpr& operator=(const MatrixBase<Derived>& matrix); MatrixXpr<MatrixRow<MatrixXpr<Content> > > row(int i); MatrixXpr<MatrixCol<MatrixXpr<Content> > > col(int i); diff --git a/src/ScalarOps.h b/src/ScalarOps.h index 26698c751..5d0a16ec9 100644 --- a/src/ScalarOps.h +++ b/src/ScalarOps.h @@ -116,6 +116,30 @@ operator *(typename Derived::Scalar scalar, return XprType(ProductType(matrix.constRef(), scalar)); } +template<typename Content> +const MatrixConstXpr< + const ScalarProduct< + MatrixConstXpr<Content> + > +> +operator /(const MatrixConstXpr<Content>& xpr, + typename Content::Scalar scalar) +{ + return xpr * (static_cast<typename Content::Scalar>(1) / scalar); +} + +template<typename Derived> +const MatrixConstXpr< + const ScalarProduct< + MatrixConstRef<MatrixBase<Derived> > + > +> +operator /(const MatrixBase<Derived>& matrix, + typename Derived::Scalar scalar) +{ + return matrix * (static_cast<typename Derived::Scalar>(1) / scalar); +} + } // namespace Eigen #endif // EIGEN_SCALAROPS_H diff --git a/src/Util.h b/src/Util.h index b789b891e..f1b51bbe1 100644 --- a/src/Util.h +++ b/src/Util.h @@ -48,6 +48,7 @@ template<typename T> class MatrixX; template<typename T, int Size> class Vector; template<typename T> class VectorX; template<typename Derived> class MatrixBase; +template<typename Derived> class MatrixAlias; template<typename T> struct ForwardDecl; template<typename T, int Rows, int Cols> struct ForwardDecl< Matrix<T, Rows, Cols> > @@ -60,12 +61,20 @@ template<typename T> struct ForwardDecl< VectorX<T> > { typedef T Scalar; }; template<typename T, int Rows, int Cols> struct ForwardDecl< MatrixBase<Matrix<T, Rows, Cols> > > { typedef T Scalar; }; +template<typename T, int Rows, int Cols> struct ForwardDecl< MatrixAlias<Matrix<T, Rows, Cols> > > +{ typedef T Scalar; }; template<typename T> struct ForwardDecl< MatrixBase<MatrixX<T> > > { typedef T Scalar; }; +template<typename T> struct ForwardDecl< MatrixAlias<MatrixX<T> > > +{ typedef T Scalar; }; template<typename T, int Size> struct ForwardDecl< MatrixBase<Vector<T, Size> > > { typedef T Scalar; }; +template<typename T, int Size> struct ForwardDecl< MatrixAlias<Vector<T, Size> > > +{ typedef T Scalar; }; template<typename T> struct ForwardDecl< MatrixBase<VectorX<T> > > { typedef T Scalar; }; +template<typename T> struct ForwardDecl< MatrixAlias<VectorX<T> > > +{ typedef T Scalar; }; template<typename MatrixType> class MatrixRef; diff --git a/test/matrixmanip.cpp b/test/matrixmanip.cpp index 153af365c..5ce8fa615 100644 --- a/test/matrixmanip.cpp +++ b/test/matrixmanip.cpp @@ -37,6 +37,7 @@ template<typename MatrixType> void matrixManip(const MatrixType& m) a.block(1, rows-1, 1, cols-1); a.xpr().row(i) = b.row(i); a.xpr().minor(i, j) = b.block(1, rows-1, 1, cols-1); + a.alias().xpr().minor(i, j) = a.block(1, rows-1, 1, cols-1); } void EigenTest::testMatrixManip() diff --git a/test/matrixops.cpp b/test/matrixops.cpp index 6e95f7893..a9034dcf2 100644 --- a/test/matrixops.cpp +++ b/test/matrixops.cpp @@ -44,7 +44,7 @@ template<typename MatrixType1, a = b; a = b + c; a = s * (b - c); - a << a + b; + a.alias() = a + b; MatrixType1 d(rows1, cols1); MatrixType2 e(rows2, cols2); diff --git a/test/vectorops.cpp b/test/vectorops.cpp index abd5ef0c2..b4b638495 100644 --- a/test/vectorops.cpp +++ b/test/vectorops.cpp @@ -42,7 +42,7 @@ template<typename VectorType> void vectorOps(const VectorType& v) a = b; a = b + c; a = s * (b - c); - a << a + b; + a.alias() = a + b; } void EigenTest::testVectorOps() |