diff options
-rw-r--r-- | Eigen/src/Core/AssignEvaluator.h | 268 | ||||
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 300 | ||||
-rw-r--r-- | test/evaluators.cpp | 2 |
3 files changed, 313 insertions, 257 deletions
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index c49c2a50f..93ca2433a 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -147,165 +147,132 @@ public: * Part 2 : meta-unrollers ***************************************************************************/ -// TODO:`Ideally, we want to use only the evaluator objects here, not the expression objects -// However, we need to access .rowIndexByOuterInner() which is in the expression object - /************************ *** Default traversal *** ************************/ -template<typename DstXprType, typename SrcXprType, int Index, int Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Index, int Stop> struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling { + typedef typename DstEvaluatorType::XprType DstXprType; + enum { outer = Index / DstXprType::InnerSizeAtCompileTime, inner = Index % DstXprType::InnerSizeAtCompileTime }; - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // TODO: Use copyCoeffByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, inner); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling<DstXprType, SrcXprType, Index+1, Stop> - ::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, Index+1, Stop> + ::run(dstEvaluator, srcEvaluator); } }; -template<typename DstXprType, typename SrcXprType, int Stop> -struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling<DstXprType, SrcXprType, Stop, Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Stop> +struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling<DstEvaluatorType, SrcEvaluatorType, Stop, Stop> { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; -template<typename DstXprType, typename SrcXprType, int Index, int Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Index, int Stop> struct copy_using_evaluator_DefaultTraversal_InnerUnrolling { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, SrcEvaluatorType &srcEvaluator, - const DstXprType &dst, int outer) { - // TODO: Use copyCoeffByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, Index); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, Index); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); - copy_using_evaluator_DefaultTraversal_InnerUnrolling<DstXprType, SrcXprType, Index+1, Stop> - ::run(dstEvaluator, srcEvaluator, dst, outer); + dstEvaluator.copyCoeffByOuterInner(outer, Index, srcEvaluator); + copy_using_evaluator_DefaultTraversal_InnerUnrolling + <DstEvaluatorType, SrcEvaluatorType, Index+1, Stop> + ::run(dstEvaluator, srcEvaluator, outer); } }; -template<typename DstXprType, typename SrcXprType, int Stop> -struct copy_using_evaluator_DefaultTraversal_InnerUnrolling<DstXprType, SrcXprType, Stop, Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Stop> +struct copy_using_evaluator_DefaultTraversal_InnerUnrolling<DstEvaluatorType, SrcEvaluatorType, Stop, Stop> { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&, int) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, int) { } }; /*********************** *** Linear traversal *** ***********************/ -template<typename DstXprType, typename SrcXprType, int Index, int Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Index, int Stop> struct copy_using_evaluator_LinearTraversal_CompleteUnrolling { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // use copyCoeff ? - dstEvaluator.coeffRef(Index) = srcEvaluator.coeff(Index); - copy_using_evaluator_LinearTraversal_CompleteUnrolling<DstXprType, SrcXprType, Index+1, Stop> - ::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.copyCoeff(Index, srcEvaluator); + copy_using_evaluator_LinearTraversal_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, Index+1, Stop> + ::run(dstEvaluator, srcEvaluator); } }; -template<typename DstXprType, typename SrcXprType, int Stop> -struct copy_using_evaluator_LinearTraversal_CompleteUnrolling<DstXprType, SrcXprType, Stop, Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Stop> +struct copy_using_evaluator_LinearTraversal_CompleteUnrolling<DstEvaluatorType, SrcEvaluatorType, Stop, Stop> { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; /************************** *** Inner vectorization *** **************************/ -template<typename DstXprType, typename SrcXprType, int Index, int Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Index, int Stop> struct copy_using_evaluator_innervec_CompleteUnrolling { + typedef typename DstEvaluatorType::XprType DstXprType; + typedef typename SrcEvaluatorType::XprType SrcXprType; + enum { outer = Index / DstXprType::InnerSizeAtCompileTime, inner = Index % DstXprType::InnerSizeAtCompileTime, JointAlignment = copy_using_evaluator_traits<DstXprType,SrcXprType>::JointAlignment }; - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // TODO: Use copyPacketByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, inner); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.template writePacket<Aligned>(row, col, srcEvaluator.template packet<JointAlignment>(row, col)); - copy_using_evaluator_innervec_CompleteUnrolling<DstXprType, SrcXprType, - Index+packet_traits<typename DstXprType::Scalar>::size, Stop>::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.template copyPacketByOuterInner<Aligned, JointAlignment>(outer, inner, srcEvaluator); + enum { NextIndex = Index + packet_traits<typename DstXprType::Scalar>::size }; + copy_using_evaluator_innervec_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, NextIndex, Stop> + ::run(dstEvaluator, srcEvaluator); } }; -template<typename DstXprType, typename SrcXprType, int Stop> -struct copy_using_evaluator_innervec_CompleteUnrolling<DstXprType, SrcXprType, Stop, Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Stop> +struct copy_using_evaluator_innervec_CompleteUnrolling<DstEvaluatorType, SrcEvaluatorType, Stop, Stop> { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; -template<typename DstXprType, typename SrcXprType, int Index, int Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Index, int Stop> struct copy_using_evaluator_innervec_InnerUnrolling { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, SrcEvaluatorType &srcEvaluator, - const DstXprType &dst, int outer) { - // TODO: Use copyPacketByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, Index); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, Index); - dstEvaluator.template writePacket<Aligned>(row, col, srcEvaluator.template packet<Aligned>(row, col)); - copy_using_evaluator_innervec_InnerUnrolling<DstXprType, SrcXprType, - Index+packet_traits<typename DstXprType::Scalar>::size, Stop>::run(dstEvaluator, srcEvaluator, dst, outer); + dstEvaluator.template copyPacketByOuterInner<Aligned, Aligned>(outer, Index, srcEvaluator); + typedef typename DstEvaluatorType::XprType DstXprType; + enum { NextIndex = Index + packet_traits<typename DstXprType::Scalar>::size }; + copy_using_evaluator_innervec_InnerUnrolling + <DstEvaluatorType, SrcEvaluatorType, NextIndex, Stop> + ::run(dstEvaluator, srcEvaluator, outer); } }; -template<typename DstXprType, typename SrcXprType, int Stop> -struct copy_using_evaluator_innervec_InnerUnrolling<DstXprType, SrcXprType, Stop, Stop> +template<typename DstEvaluatorType, typename SrcEvaluatorType, int Stop> +struct copy_using_evaluator_innervec_InnerUnrolling<DstEvaluatorType, SrcEvaluatorType, Stop, Stop> { - typedef typename evaluator<DstXprType>::type DstEvaluatorType; - typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&, int) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, int) { } }; /*************************************************************************** @@ -326,20 +293,18 @@ struct copy_using_evaluator_impl; template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, DefaultTraversal, NoUnrolling> { - static void run(const DstXprType& dst, const SrcXprType& src) + static void run(DstXprType& dst, const SrcXprType& src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); for(Index outer = 0; outer < dst.outerSize(); ++outer) { for(Index inner = 0; inner < dst.innerSize(); ++inner) { - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); // TODO: use copyCoeff ? + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); } } } @@ -348,16 +313,17 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, DefaultTraversal, NoUnr template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, DefaultTraversal, CompleteUnrolling> { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling<DstXprType, SrcXprType, 0, DstXprType::SizeAtCompileTime> - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, DstXprType::SizeAtCompileTime> + ::run(dstEvaluator, srcEvaluator); } }; @@ -365,18 +331,19 @@ template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, DefaultTraversal, InnerUnrolling> { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index outerSize = dst.outerSize(); for(Index outer = 0; outer < outerSize; ++outer) - copy_using_evaluator_DefaultTraversal_InnerUnrolling<DstXprType, SrcXprType, 0, DstXprType::InnerSizeAtCompileTime> - ::run(dstEvaluator, srcEvaluator, dst, outer); + copy_using_evaluator_DefaultTraversal_InnerUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, DstXprType::InnerSizeAtCompileTime> + ::run(dstEvaluator, srcEvaluator, outer); } }; @@ -387,43 +354,46 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, DefaultTraversal, Inner template <bool IsAligned = false> struct unaligned_copy_using_evaluator_impl { + // if IsAligned = true, then do nothing template <typename SrcEvaluatorType, typename DstEvaluatorType> static EIGEN_STRONG_INLINE void run(const SrcEvaluatorType&, DstEvaluatorType&, typename SrcEvaluatorType::Index, typename SrcEvaluatorType::Index) {} }; -// TODO: check why no ...<true> ???? - template <> struct unaligned_copy_using_evaluator_impl<false> { // MSVC must not inline this functions. If it does, it fails to optimize the // packet access path. #ifdef _MSC_VER - template <typename SrcEvaluatorType, typename DstEvaluatorType> - static EIGEN_DONT_INLINE void run(const SrcEvaluatorType& src, DstEvaluatorType& dst, - typename SrcEvaluatorType::Index start, typename SrcEvaluatorType::Index end) + template <typename DstEvaluatorType, typename SrcEvaluatorType> + static EIGEN_DONT_INLINE void run(DstEvaluatorType &dstEvaluator, + const SrcEvaluatorType &srcEvaluator, + typename DstEvaluatorType::Index start, + typename DstEvaluatorType::Index end) #else - template <typename SrcEvaluatorType, typename DstEvaluatorType> - static EIGEN_STRONG_INLINE void run(const SrcEvaluatorType& src, DstEvaluatorType& dst, - typename SrcEvaluatorType::Index start, typename SrcEvaluatorType::Index end) + template <typename DstEvaluatorType, typename SrcEvaluatorType> + static EIGEN_STRONG_INLINE void run(DstEvaluatorType &dstEvaluator, + const SrcEvaluatorType &srcEvaluator, + typename DstEvaluatorType::Index start, + typename DstEvaluatorType::Index end) #endif { - for (typename SrcEvaluatorType::Index index = start; index < end; ++index) - dst.copyCoeff(index, src); + for (typename DstEvaluatorType::Index index = start; index < end; ++index) + dstEvaluator.copyCoeff(index, srcEvaluator); } }; template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTraversal, NoUnrolling> { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index size = dst.size(); @@ -437,14 +407,14 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTravers const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dst.coeffRef(0), size); const Index alignedEnd = alignedStart + ((size-alignedStart)/packetSize)*packetSize; - unaligned_copy_using_evaluator_impl<dstIsAligned!=0>::run(src,dst.const_cast_derived(),0,alignedStart); + unaligned_copy_using_evaluator_impl<dstIsAligned!=0>::run(dstEvaluator, srcEvaluator, 0, alignedStart); for(Index index = alignedStart; index < alignedEnd; index += packetSize) { - dstEvaluator.template writePacket<dstAlignment>(index, srcEvaluator.template packet<srcAlignment>(index)); + dstEvaluator.template copyPacket<dstAlignment, srcAlignment>(index, srcEvaluator); } - unaligned_copy_using_evaluator_impl<>::run(src,dst.const_cast_derived(),alignedEnd,size); + unaligned_copy_using_evaluator_impl<>::run(dstEvaluator, srcEvaluator, alignedEnd, size); } }; @@ -452,22 +422,24 @@ template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTraversal, CompleteUnrolling> { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); enum { size = DstXprType::SizeAtCompileTime, packetSize = packet_traits<typename DstXprType::Scalar>::size, alignedSize = (size/packetSize)*packetSize }; - copy_using_evaluator_innervec_CompleteUnrolling<DstXprType, SrcXprType, 0, alignedSize> - ::run(dstEvaluator, srcEvaluator, dst); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling<DstXprType, SrcXprType, alignedSize, size> - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_innervec_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, alignedSize> + ::run(dstEvaluator, srcEvaluator); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, alignedSize, size> + ::run(dstEvaluator, srcEvaluator); } }; @@ -478,13 +450,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTravers template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversal, NoUnrolling> { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index innerSize = dst.innerSize(); @@ -492,10 +464,7 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversa const Index packetSize = packet_traits<typename DstXprType::Scalar>::size; for(Index outer = 0; outer < outerSize; ++outer) for(Index inner = 0; inner < innerSize; inner+=packetSize) { - // TODO: Use copyPacketByOuterInner ? - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.template writePacket<Aligned>(row, col, srcEvaluator.template packet<Aligned>(row, col)); + dstEvaluator.template copyPacketByOuterInner<Aligned, Aligned>(outer, inner, srcEvaluator); } } }; @@ -503,16 +472,17 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversa template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversal, CompleteUnrolling> { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_innervec_CompleteUnrolling<DstXprType, SrcXprType, 0, DstXprType::SizeAtCompileTime> - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_innervec_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, DstXprType::SizeAtCompileTime> + ::run(dstEvaluator, srcEvaluator); } }; @@ -520,18 +490,19 @@ template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversal, InnerUnrolling> { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index outerSize = dst.outerSize(); for(Index outer = 0; outer < outerSize; ++outer) - copy_using_evaluator_innervec_InnerUnrolling<DstXprType, SrcXprType, 0, DstXprType::InnerSizeAtCompileTime> - ::run(dstEvaluator, srcEvaluator, dst, outer); + copy_using_evaluator_innervec_InnerUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, DstXprType::InnerSizeAtCompileTime> + ::run(dstEvaluator, srcEvaluator, outer); } }; @@ -542,34 +513,35 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, InnerVectorizedTraversa template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearTraversal, NoUnrolling> { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index size = dst.size(); for(Index i = 0; i < size; ++i) - dstEvaluator.coeffRef(i) = srcEvaluator.coeff(i); // TODO: use copyCoeff ? + dstEvaluator.copyCoeff(i, srcEvaluator); } }; template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearTraversal, CompleteUnrolling> { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_LinearTraversal_CompleteUnrolling<DstXprType, SrcXprType, 0, DstXprType::SizeAtCompileTime> - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_LinearTraversal_CompleteUnrolling + <DstEvaluatorType, SrcEvaluatorType, 0, DstXprType::SizeAtCompileTime> + ::run(dstEvaluator, srcEvaluator); } }; @@ -580,13 +552,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearTraversal, Comple template<typename DstXprType, typename SrcXprType> struct copy_using_evaluator_impl<DstXprType, SrcXprType, SliceVectorizedTraversal, NoUnrolling> { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator<DstXprType>::type DstEvaluatorType; typedef typename evaluator<SrcXprType>::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); typedef packet_traits<typename DstXprType::Scalar> PacketTraits; @@ -608,23 +580,17 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, SliceVectorizedTraversa const Index alignedEnd = alignedStart + ((innerSize-alignedStart) & ~packetAlignedMask); // do the non-vectorizable part of the assignment for(Index inner = 0; inner<alignedStart ; ++inner) { - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); } // do the vectorizable part of the assignment for(Index inner = alignedStart; inner<alignedEnd; inner+=packetSize) { - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.template writePacket<dstAlignment>(row, col, srcEvaluator.template packet<srcAlignment>(row, col)); + dstEvaluator.template copyPacketByOuterInner<dstAlignment, srcAlignment>(outer, inner, srcEvaluator); } // do the non-vectorizable part of the assignment for(Index inner = alignedEnd; inner<innerSize ; ++inner) { - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); } alignedStart = std::min<Index>((alignedStart+alignedStep)%packetSize, innerSize); @@ -644,7 +610,7 @@ const DstXprType& copy_using_evaluator(const DstXprType& dst, const SrcXprType& #ifdef EIGEN_DEBUG_ASSIGN internal::copy_using_evaluator_traits<DstXprType, SrcXprType>::debug(); #endif - copy_using_evaluator_impl<DstXprType, SrcXprType>::run(dst, src); + copy_using_evaluator_impl<DstXprType, SrcXprType>::run(const_cast<DstXprType&>(dst), src); return dst; } diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 1ef82e4be..187dc1c97 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -47,63 +47,140 @@ struct evaluator<const T> typedef evaluator_impl<T> type; }; -// -------------------- Transpose -------------------- +// ---------- base class for all writable evaluators ---------- template<typename ExpressionType> -struct evaluator_impl<Transpose<ExpressionType> > +struct evaluator_impl_base { - typedef Transpose<ExpressionType> TransposeType; - evaluator_impl(const TransposeType& t) : m_argImpl(t.nestedExpression()) {} + typedef typename ExpressionType::Index Index; + + template<typename OtherEvaluatorType> + void copyCoeff(Index row, Index col, const OtherEvaluatorType& other) + { + derived().coeffRef(row, col) = other.coeff(row, col); + } + + template<typename OtherEvaluatorType> + void copyCoeffByOuterInner(Index outer, Index inner, const OtherEvaluatorType& other) + { + Index row = rowIndexByOuterInner(outer, inner); + Index col = colIndexByOuterInner(outer, inner); + derived().coeffRef(row, col) = other.coeff(row, col); + } + + template<typename OtherEvaluatorType> + void copyCoeff(Index index, const OtherEvaluatorType& other) + { + derived().coeffRef(index) = other.coeff(index); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index row, Index col, const OtherEvaluatorType& other) + { + derived().template writePacket<StoreMode>(row, col, + other.template packet<LoadMode>(row, col)); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacketByOuterInner(Index outer, Index inner, const OtherEvaluatorType& other) + { + Index row = rowIndexByOuterInner(outer, inner); + Index col = colIndexByOuterInner(outer, inner); + derived().template writePacket<StoreMode>(row, col, + other.template packet<LoadMode>(row, col)); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index index, const OtherEvaluatorType& other) + { + derived().template writePacket<StoreMode>(index, + other.template packet<LoadMode>(index)); + } - typedef typename TransposeType::Index Index; + Index rowIndexByOuterInner(Index outer, Index inner) const + { + return int(ExpressionType::RowsAtCompileTime) == 1 ? 0 + : int(ExpressionType::ColsAtCompileTime) == 1 ? inner + : int(ExpressionType::Flags)&RowMajorBit ? outer + : inner; + } - typename TransposeType::CoeffReturnType coeff(Index i, Index j) const + Index colIndexByOuterInner(Index outer, Index inner) const { - return m_argImpl.coeff(j, i); + return int(ExpressionType::ColsAtCompileTime) == 1 ? 0 + : int(ExpressionType::RowsAtCompileTime) == 1 ? inner + : int(ExpressionType::Flags)&RowMajorBit ? inner + : outer; } - typename TransposeType::CoeffReturnType coeff(Index index) const + evaluator_impl<ExpressionType>& derived() + { + return *static_cast<evaluator_impl<ExpressionType>*>(this); + } +}; + +// -------------------- Transpose -------------------- + +template<typename ArgType> +struct evaluator_impl<Transpose<ArgType> > + : evaluator_impl_base<Transpose<ArgType> > +{ + typedef Transpose<ArgType> XprType; + + evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(col, row); + } + + CoeffReturnType coeff(Index index) const { return m_argImpl.coeff(index); } - typename TransposeType::Scalar& coeffRef(Index i, Index j) + Scalar& coeffRef(Index row, Index col) { - return m_argImpl.coeffRef(j, i); + return m_argImpl.coeffRef(col, row); } - typename TransposeType::Scalar& coeffRef(Index index) + typename XprType::Scalar& coeffRef(Index index) { return m_argImpl.coeffRef(index); } - // TODO: Difference between PacketScalar and PacketReturnType? template<int LoadMode> - const typename ExpressionType::PacketScalar packet(Index row, Index col) const + PacketReturnType packet(Index row, Index col) const { return m_argImpl.template packet<LoadMode>(col, row); } template<int LoadMode> - const typename ExpressionType::PacketScalar packet(Index index) const + PacketReturnType packet(Index index) const { return m_argImpl.template packet<LoadMode>(index); } template<int StoreMode> - void writePacket(Index row, Index col, const typename ExpressionType::PacketScalar& x) + void writePacket(Index row, Index col, const PacketScalar& x) { m_argImpl.template writePacket<StoreMode>(col, row, x); } template<int StoreMode> - void writePacket(Index index, const typename ExpressionType::PacketScalar& x) + void writePacket(Index index, const PacketScalar& x) { m_argImpl.template writePacket<StoreMode>(index, x); } protected: - typename evaluator<ExpressionType>::type m_argImpl; + typename evaluator<ArgType>::type m_argImpl; }; // -------------------- Matrix and Array -------------------- @@ -113,6 +190,7 @@ protected: template<typename Derived> struct evaluator_impl<PlainObjectBase<Derived> > + : evaluator_impl_base<Derived> { typedef PlainObjectBase<Derived> PlainObjectType; @@ -176,10 +254,10 @@ template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxC struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > : evaluator_impl<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > > { - typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType; + typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType; - evaluator_impl(const MatrixType& m) - : evaluator_impl<PlainObjectBase<MatrixType> >(m) + evaluator_impl(const XprType& m) + : evaluator_impl<PlainObjectBase<XprType> >(m) { } }; @@ -187,10 +265,10 @@ template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxC struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > : evaluator_impl<PlainObjectBase<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > > { - typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType; + typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> XprType; - evaluator_impl(const ArrayType& m) - : evaluator_impl<PlainObjectBase<ArrayType> >(m) + evaluator_impl(const XprType& m) + : evaluator_impl<PlainObjectBase<XprType> >(m) { } }; @@ -199,15 +277,15 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > template<typename NullaryOp, typename PlainObjectType> struct evaluator_impl<CwiseNullaryOp<NullaryOp,PlainObjectType> > { - typedef CwiseNullaryOp<NullaryOp,PlainObjectType> NullaryOpType; + typedef CwiseNullaryOp<NullaryOp,PlainObjectType> XprType; - evaluator_impl(const NullaryOpType& n) + evaluator_impl(const XprType& n) : m_functor(n.functor()) { } - typedef typename NullaryOpType::Index Index; - typedef typename NullaryOpType::CoeffReturnType CoeffReturnType; - typedef typename NullaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -240,16 +318,16 @@ protected: template<typename UnaryOp, typename ArgType> struct evaluator_impl<CwiseUnaryOp<UnaryOp, ArgType> > { - typedef CwiseUnaryOp<UnaryOp, ArgType> UnaryOpType; + typedef CwiseUnaryOp<UnaryOp, ArgType> XprType; - evaluator_impl(const UnaryOpType& op) + evaluator_impl(const XprType& op) : m_functor(op.functor()), m_argImpl(op.nestedExpression()) { } - typedef typename UnaryOpType::Index Index; - typedef typename UnaryOpType::CoeffReturnType CoeffReturnType; - typedef typename UnaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -283,17 +361,17 @@ protected: template<typename BinaryOp, typename Lhs, typename Rhs> struct evaluator_impl<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > { - typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> BinaryOpType; + typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; - evaluator_impl(const BinaryOpType& xpr) + evaluator_impl(const XprType& xpr) : m_functor(xpr.functor()), m_lhsImpl(xpr.lhs()), m_rhsImpl(xpr.rhs()) { } - typedef typename BinaryOpType::Index Index; - typedef typename BinaryOpType::CoeffReturnType CoeffReturnType; - typedef typename BinaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -329,17 +407,18 @@ protected: template<typename UnaryOp, typename ArgType> struct evaluator_impl<CwiseUnaryView<UnaryOp, ArgType> > + : evaluator_impl_base<CwiseUnaryView<UnaryOp, ArgType> > { - typedef CwiseUnaryView<UnaryOp, ArgType> CwiseUnaryViewType; + typedef CwiseUnaryView<UnaryOp, ArgType> XprType; - evaluator_impl(const CwiseUnaryViewType& op) + evaluator_impl(const XprType& op) : m_unaryOp(op.functor()), m_argImpl(op.nestedExpression()) { } - typedef typename CwiseUnaryViewType::Index Index; - typedef typename CwiseUnaryViewType::Scalar Scalar; - typedef typename CwiseUnaryViewType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -400,23 +479,26 @@ protected: template<typename Derived, int AccessorsType> struct evaluator_impl<MapBase<Derived, AccessorsType> > + : evaluator_impl_base<Derived> { typedef MapBase<Derived, AccessorsType> MapType; - typedef typename MapType::PointerType PointerType; - typedef typename MapType::Index Index; - typedef typename MapType::Scalar Scalar; - typedef typename MapType::CoeffReturnType CoeffReturnType; - typedef typename MapType::PacketScalar PacketScalar; - typedef typename MapType::PacketReturnType PacketReturnType; + typedef Derived XprType; + + typedef typename XprType::PointerType PointerType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; - evaluator_impl(const MapType& map) + evaluator_impl(const XprType& map) : m_data(const_cast<PointerType>(map.data())), m_rowStride(map.rowStride()), m_colStride(map.colStride()) { } enum { - RowsAtCompileTime = MapType::RowsAtCompileTime + RowsAtCompileTime = XprType::RowsAtCompileTime }; CoeffReturnType coeff(Index row, Index col) const @@ -480,34 +562,35 @@ template<typename PlainObjectType, int MapOptions, typename StrideType> struct evaluator_impl<Map<PlainObjectType, MapOptions, StrideType> > : public evaluator_impl<MapBase<Map<PlainObjectType, MapOptions, StrideType> > > { - typedef Map<PlainObjectType, MapOptions, StrideType> MapType; + typedef Map<PlainObjectType, MapOptions, StrideType> XprType; - evaluator_impl(const MapType& map) - : evaluator_impl<MapBase<MapType> >(map) + evaluator_impl(const XprType& map) + : evaluator_impl<MapBase<XprType> >(map) { } }; // -------------------- Block -------------------- -template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel> -struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDirectAccess */ false> > +template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel> +struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDirectAccess */ false> > + : evaluator_impl_base<Block<ArgType, BlockRows, BlockCols, InnerPanel, false> > { - typedef Block<XprType, BlockRows, BlockCols, InnerPanel, false> BlockType; + typedef Block<ArgType, BlockRows, BlockCols, InnerPanel, false> XprType; - evaluator_impl(const BlockType& block) + evaluator_impl(const XprType& block) : m_argImpl(block.nestedExpression()), m_startRow(block.startRow()), m_startCol(block.startCol()) { } - typedef typename BlockType::Index Index; - typedef typename BlockType::Scalar Scalar; - typedef typename BlockType::CoeffReturnType CoeffReturnType; - typedef typename BlockType::PacketScalar PacketScalar; - typedef typename BlockType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; enum { - RowsAtCompileTime = BlockType::RowsAtCompileTime + RowsAtCompileTime = XprType::RowsAtCompileTime }; CoeffReturnType coeff(Index row, Index col) const @@ -560,7 +643,7 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir } protected: - typename evaluator<XprType>::type m_argImpl; + typename evaluator<ArgType>::type m_argImpl; // TODO: Get rid of m_startRow, m_startCol if known at compile time Index m_startRow; @@ -570,14 +653,14 @@ protected: // TODO: This evaluator does not actually use the child evaluator; // all action is via the data() as returned by the Block expression. -template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel> -struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDirectAccess */ true> > - : evaluator_impl<MapBase<Block<XprType, BlockRows, BlockCols, InnerPanel, true> > > +template<typename ArgType, int BlockRows, int BlockCols, bool InnerPanel> +struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDirectAccess */ true> > + : evaluator_impl<MapBase<Block<ArgType, BlockRows, BlockCols, InnerPanel, true> > > { - typedef Block<XprType, BlockRows, BlockCols, InnerPanel, true> BlockType; + typedef Block<ArgType, BlockRows, BlockCols, InnerPanel, true> XprType; - evaluator_impl(const BlockType& block) - : evaluator_impl<MapBase<BlockType> >(block) + evaluator_impl(const XprType& block) + : evaluator_impl<MapBase<XprType> >(block) { } }; @@ -587,16 +670,16 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir template<typename ConditionMatrixType, typename ThenMatrixType, typename ElseMatrixType> struct evaluator_impl<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> > { - typedef Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> SelectType; + typedef Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> XprType; - evaluator_impl(const SelectType& select) + evaluator_impl(const XprType& select) : m_conditionImpl(select.conditionMatrix()), m_thenImpl(select.thenMatrix()), m_elseImpl(select.elseMatrix()) { } - typedef typename SelectType::Index Index; - typedef typename SelectType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -623,20 +706,20 @@ protected: // -------------------- Replicate -------------------- -template<typename XprType, int RowFactor, int ColFactor> -struct evaluator_impl<Replicate<XprType, RowFactor, ColFactor> > +template<typename ArgType, int RowFactor, int ColFactor> +struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> > { - typedef Replicate<XprType, RowFactor, ColFactor> ReplicateType; + typedef Replicate<ArgType, RowFactor, ColFactor> XprType; - evaluator_impl(const ReplicateType& replicate) + evaluator_impl(const XprType& replicate) : m_argImpl(replicate.nestedExpression()), m_rows(replicate.nestedExpression().rows()), m_cols(replicate.nestedExpression().cols()) { } - typedef typename ReplicateType::Index Index; - typedef typename ReplicateType::CoeffReturnType CoeffReturnType; - typedef typename ReplicateType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -665,7 +748,7 @@ struct evaluator_impl<Replicate<XprType, RowFactor, ColFactor> > } protected: - typename evaluator<XprType>::type m_argImpl; + typename evaluator<ArgType>::type m_argImpl; Index m_rows; // TODO: Get rid of this if known at compile time Index m_cols; }; @@ -677,17 +760,17 @@ protected: // TODO: Find out how to write a proper evaluator without duplicating // the row() and col() member functions. -template< typename XprType, typename MemberOp, int Direction> -struct evaluator_impl<PartialReduxExpr<XprType, MemberOp, Direction> > +template< typename ArgType, typename MemberOp, int Direction> +struct evaluator_impl<PartialReduxExpr<ArgType, MemberOp, Direction> > { - typedef PartialReduxExpr<XprType, MemberOp, Direction> PartialReduxExprType; + typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType; - evaluator_impl(const PartialReduxExprType expr) + evaluator_impl(const XprType expr) : m_expr(expr) { } - typedef typename PartialReduxExprType::Index Index; - typedef typename PartialReduxExprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -700,7 +783,7 @@ struct evaluator_impl<PartialReduxExpr<XprType, MemberOp, Direction> > } protected: - const PartialReduxExprType& m_expr; + const XprType& m_expr; }; @@ -711,6 +794,7 @@ protected: template<typename ArgType> struct evaluator_impl_wrapper_base + : evaluator_impl_base<ArgType> { evaluator_impl_wrapper_base(const ArgType& arg) : m_argImpl(arg) {} @@ -772,7 +856,9 @@ template<typename ArgType> struct evaluator_impl<MatrixWrapper<ArgType> > : evaluator_impl_wrapper_base<ArgType> { - evaluator_impl(const MatrixWrapper<ArgType>& wrapper) + typedef MatrixWrapper<ArgType> XprType; + + evaluator_impl(const XprType& wrapper) : evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression()) { } }; @@ -781,7 +867,9 @@ template<typename ArgType> struct evaluator_impl<ArrayWrapper<ArgType> > : evaluator_impl_wrapper_base<ArgType> { - evaluator_impl(const ArrayWrapper<ArgType>& wrapper) + typedef ArrayWrapper<ArgType> XprType; + + evaluator_impl(const XprType& wrapper) : evaluator_impl_wrapper_base<ArgType>(wrapper.nestedExpression()) { } }; @@ -794,24 +882,25 @@ template<typename PacketScalar, bool ReversePacket> struct reverse_packet_cond; template<typename ArgType, int Direction> struct evaluator_impl<Reverse<ArgType, Direction> > + : evaluator_impl_base<Reverse<ArgType, Direction> > { - typedef Reverse<ArgType, Direction> ReverseType; + typedef Reverse<ArgType, Direction> XprType; - evaluator_impl(const ReverseType& reverse) + evaluator_impl(const XprType& reverse) : m_argImpl(reverse.nestedExpression()), m_rows(reverse.nestedExpression().rows()), m_cols(reverse.nestedExpression().cols()) { } - typedef typename ReverseType::Index Index; - typedef typename ReverseType::Scalar Scalar; - typedef typename ReverseType::CoeffReturnType CoeffReturnType; - typedef typename ReverseType::PacketScalar PacketScalar; - typedef typename ReverseType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; enum { PacketSize = internal::packet_traits<Scalar>::size, - IsRowMajor = ReverseType::IsRowMajor, + IsRowMajor = XprType::IsRowMajor, IsColMajor = !IsRowMajor, ReverseRow = (Direction == Vertical) || (Direction == BothDirections), ReverseCol = (Direction == Horizontal) || (Direction == BothDirections), @@ -885,17 +974,18 @@ protected: template<typename ArgType, int DiagIndex> struct evaluator_impl<Diagonal<ArgType, DiagIndex> > + : evaluator_impl_base<Diagonal<ArgType, DiagIndex> > { - typedef Diagonal<ArgType, DiagIndex> DiagonalType; + typedef Diagonal<ArgType, DiagIndex> XprType; - evaluator_impl(const DiagonalType& diagonal) + evaluator_impl(const XprType& diagonal) : m_argImpl(diagonal.nestedExpression()), m_index(diagonal.index()) { } - typedef typename DiagonalType::Index Index; - typedef typename DiagonalType::Scalar Scalar; - typedef typename DiagonalType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index) const { diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 5a123f0ad..ea957cb1e 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -214,5 +214,5 @@ void test_evaluators() copy_using_evaluator(mat1.diagonal<-1>(), mat1.diagonal(1)); mat2.diagonal<-1>() = mat2.diagonal(1); - VERIFY_IS_APPROX(mat1, mat2); + VERIFY_IS_APPROX(mat1, mat2); } |