diff options
-rw-r--r-- | Eigen/src/Core/AssignEvaluator.h | 12 | ||||
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 74 | ||||
-rw-r--r-- | Eigen/src/Core/Swap.h | 2 | ||||
-rw-r--r-- | test/evaluators.cpp | 24 |
4 files changed, 107 insertions, 5 deletions
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index 93ca2433a..c5f345a2f 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -404,7 +404,7 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTravers dstAlignment = PacketTraits::AlignedOnScalar ? Aligned : dstIsAligned, srcAlignment = copy_using_evaluator_traits<DstXprType,SrcXprType>::JointAlignment }; - const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dst.coeffRef(0), size); + const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dstEvaluator.coeffRef(0), size); const Index alignedEnd = alignedStart + ((size-alignedStart)/packetSize)*packetSize; unaligned_copy_using_evaluator_impl<dstIsAligned!=0>::run(dstEvaluator, srcEvaluator, 0, alignedStart); @@ -614,6 +614,16 @@ const DstXprType& copy_using_evaluator(const DstXprType& dst, const SrcXprType& return dst; } +// Based on DenseBase::swap() +// TODO: Chech whether we need to do something special for swapping two +// Arrays or Matrices. + +template<typename DstXprType, typename SrcXprType> +void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src) +{ + copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src); +} + } // namespace internal #endif // EIGEN_ASSIGN_EVALUATOR_H diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 187dc1c97..899aa04ea 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -65,7 +65,7 @@ struct evaluator_impl_base { Index row = rowIndexByOuterInner(outer, inner); Index col = colIndexByOuterInner(outer, inner); - derived().coeffRef(row, col) = other.coeff(row, col); + derived().copyCoeff(row, col, other); } template<typename OtherEvaluatorType> @@ -86,8 +86,7 @@ struct evaluator_impl_base { Index row = rowIndexByOuterInner(outer, inner); Index col = colIndexByOuterInner(outer, inner); - derived().template writePacket<StoreMode>(row, col, - other.template packet<LoadMode>(row, col)); + derived().template copyPacket<StoreMode, LoadMode>(row, col, other); } template<int StoreMode, int LoadMode, typename OtherEvaluatorType> @@ -1017,6 +1016,75 @@ private: }; +// ---------- SwapWrapper ---------- + +template<typename ArgType> +struct evaluator_impl<SwapWrapper<ArgType> > + : evaluator_impl_base<SwapWrapper<ArgType> > +{ + typedef SwapWrapper<ArgType> XprType; + + evaluator_impl(const XprType& swapWrapper) + : m_argImpl(swapWrapper.expression()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::Packet Packet; + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + inline Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template<typename OtherEvaluatorType> + void copyCoeff(Index row, Index col, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other); + Scalar tmp = m_argImpl.coeff(row, col); + m_argImpl.coeffRef(row, col) = nonconst_other.coeff(row, col); + nonconst_other.coeffRef(row, col) = tmp; + } + + template<typename OtherEvaluatorType> + void copyCoeff(Index index, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other); + Scalar tmp = m_argImpl.coeff(index); + m_argImpl.coeffRef(index) = nonconst_other.coeff(index); + nonconst_other.coeffRef(index) = tmp; + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index row, Index col, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other); + Packet tmp = m_argImpl.template packet<StoreMode>(row, col); + m_argImpl.template writePacket<StoreMode> + (row, col, nonconst_other.template packet<LoadMode>(row, col)); + nonconst_other.template writePacket<LoadMode>(row, col, tmp); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index index, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other); + Packet tmp = m_argImpl.template packet<StoreMode>(index); + m_argImpl.template writePacket<StoreMode> + (index, nonconst_other.template packet<LoadMode>(index)); + nonconst_other.template writePacket<LoadMode>(index, tmp); + } + +protected: + typename evaluator<ArgType>::type m_argImpl; +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/Swap.h b/Eigen/src/Core/Swap.h index 5fb032866..5fdd36e3b 100644 --- a/Eigen/src/Core/Swap.h +++ b/Eigen/src/Core/Swap.h @@ -119,6 +119,8 @@ template<typename ExpressionType> class SwapWrapper _other.template writePacket<LoadMode>(index, tmp); } + ExpressionType& expression() const { return m_expression; } + protected: ExpressionType& m_expression; }; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index ea957cb1e..6e81ad5ef 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -214,5 +214,27 @@ void test_evaluators() copy_using_evaluator(mat1.diagonal<-1>(), mat1.diagonal(1)); mat2.diagonal<-1>() = mat2.diagonal(1); - VERIFY_IS_APPROX(mat1, mat2); + VERIFY_IS_APPROX(mat1, mat2); + + { + // test swapping + MatrixXd mat1, mat2, mat1ref, mat2ref; + mat1ref = mat1 = MatrixXd::Random(6, 6); + mat2ref = mat2 = 2 * mat1 + MatrixXd::Identity(6, 6); + swap_using_evaluator(mat1, mat2); + mat1ref.swap(mat2ref); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + + swap_using_evaluator(mat1.block(0, 0, 3, 3), mat2.block(3, 3, 3, 3)); + mat1ref.block(0, 0, 3, 3).swap(mat2ref.block(3, 3, 3, 3)); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + + swap_using_evaluator(mat1.row(2), mat2.col(3).transpose()); + mat1.row(2).swap(mat2.col(3).transpose()); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + } + } |