diff options
Diffstat (limited to 'Eigen/Core/OperatorEquals.h')
-rw-r--r-- | Eigen/Core/OperatorEquals.h | 79 |
1 files changed, 64 insertions, 15 deletions
diff --git a/Eigen/Core/OperatorEquals.h b/Eigen/Core/OperatorEquals.h index 8206d280f..efda4c033 100644 --- a/Eigen/Core/OperatorEquals.h +++ b/Eigen/Core/OperatorEquals.h @@ -28,27 +28,27 @@ #define EIGEN_OPERATOREQUALS_H template<typename Derived1, typename Derived2, int UnrollCount, int Rows> -struct OperatorEqualsUnroller +struct MatrixOperatorEqualsUnroller { static const int col = (UnrollCount-1) / Rows; static const int row = (UnrollCount-1) % Rows; static void run(Derived1 &dst, const Derived2 &src) { - OperatorEqualsUnroller<Derived1, Derived2, UnrollCount-1, Rows>::run(dst, src); + MatrixOperatorEqualsUnroller<Derived1, Derived2, UnrollCount-1, Rows>::run(dst, src); dst.coeffRef(row, col) = src.coeff(row, col); } }; // prevent buggy user code from causing an infinite recursion template<typename Derived1, typename Derived2, int UnrollCount> -struct OperatorEqualsUnroller<Derived1, Derived2, UnrollCount, 0> +struct MatrixOperatorEqualsUnroller<Derived1, Derived2, UnrollCount, 0> { static void run(Derived1 &, const Derived2 &) {} }; template<typename Derived1, typename Derived2, int Rows> -struct OperatorEqualsUnroller<Derived1, Derived2, 1, Rows> +struct MatrixOperatorEqualsUnroller<Derived1, Derived2, 1, Rows> { static void run(Derived1 &dst, const Derived2 &src) { @@ -57,7 +57,41 @@ struct OperatorEqualsUnroller<Derived1, Derived2, 1, Rows> }; template<typename Derived1, typename Derived2, int Rows> -struct OperatorEqualsUnroller<Derived1, Derived2, Dynamic, Rows> +struct MatrixOperatorEqualsUnroller<Derived1, Derived2, Dynamic, Rows> +{ + static void run(Derived1 &, const Derived2 &) {} +}; + +template<typename Derived1, typename Derived2, int UnrollCount> +struct VectorOperatorEqualsUnroller +{ + static const int index = UnrollCount - 1; + + static void run(Derived1 &dst, const Derived2 &src) + { + VectorOperatorEqualsUnroller<Derived1, Derived2, UnrollCount-1>::run(dst, src); + dst.coeffRef(index) = src.coeff(index); + } +}; + +// prevent buggy user code from causing an infinite recursion +template<typename Derived1, typename Derived2> +struct VectorOperatorEqualsUnroller<Derived1, Derived2, 0> +{ + static void run(Derived1 &, const Derived2 &) {} +}; + +template<typename Derived1, typename Derived2> +struct VectorOperatorEqualsUnroller<Derived1, Derived2, 1> +{ + static void run(Derived1 &dst, const Derived2 &src) + { + dst.coeffRef(0) = src.coeff(0); + } +}; + +template<typename Derived1, typename Derived2> +struct VectorOperatorEqualsUnroller<Derived1, Derived2, Dynamic> { static void run(Derived1 &, const Derived2 &) {} }; @@ -67,16 +101,31 @@ template<typename OtherDerived> Derived& MatrixBase<Scalar, Derived> ::operator=(const MatrixBase<Scalar, OtherDerived>& other) { - assert(rows() == other.rows() && cols() == other.cols()); - if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25) - OperatorEqualsUnroller - <Derived, OtherDerived, SizeAtCompileTime, RowsAtCompileTime>::run - (*static_cast<Derived*>(this), *static_cast<const OtherDerived*>(&other)); - else - for(int j = 0; j < cols(); j++) //traverse in column-dominant order - for(int i = 0; i < rows(); i++) - coeffRef(i, j) = other.coeff(i, j); - return *static_cast<Derived*>(this); + if(IsVector && OtherDerived::IsVector) // copying a vector expression into a vector + { + assert(size() == other.size()); + if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25) + VectorOperatorEqualsUnroller + <Derived, OtherDerived, SizeAtCompileTime>::run + (*static_cast<Derived*>(this), *static_cast<const OtherDerived*>(&other)); + else + for(int i = 0; i < size(); i++) + coeffRef(i) = other.coeff(i); + return *static_cast<Derived*>(this); + } + else // all other cases (typically, but not necessarily, copying a matrix) + { + assert(rows() == other.rows() && cols() == other.cols()); + if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25) + MatrixOperatorEqualsUnroller + <Derived, OtherDerived, SizeAtCompileTime, RowsAtCompileTime>::run + (*static_cast<Derived*>(this), *static_cast<const OtherDerived*>(&other)); + else + for(int j = 0; j < cols(); j++) //traverse in column-dominant order + for(int i = 0; i < rows(); i++) + coeffRef(i, j) = other.coeff(i, j); + return *static_cast<Derived*>(this); + } } #endif // EIGEN_OPERATOREQUALS_H |