diff options
Diffstat (limited to 'Eigen/src/Core/CacheFriendlyProduct.h')
-rw-r--r-- | Eigen/src/Core/CacheFriendlyProduct.h | 42 |
1 files changed, 14 insertions, 28 deletions
diff --git a/Eigen/src/Core/CacheFriendlyProduct.h b/Eigen/src/Core/CacheFriendlyProduct.h index 9d4b0af36..bd9f4d0d9 100644 --- a/Eigen/src/Core/CacheFriendlyProduct.h +++ b/Eigen/src/Core/CacheFriendlyProduct.h @@ -419,16 +419,19 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_colmajor_times_vector( ei_internal_assert((alignmentPattern==NoneAligned) || (size_t(lhs+alignedStart+lhsStride*skipColumns)%sizeof(Packet))==0); } + + int offset1 = (FirstAligned && alignmentStep==1?3:1); + int offset3 = (FirstAligned && alignmentStep==1?1:3); int columnBound = ((rhs.size()-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns; for (int i=skipColumns; i<columnBound; i+=columnsAtOnce) { - Packet ptmp0 = ei_pset1(rhs[i]), ptmp1 = ei_pset1(rhs[i+1]), - ptmp2 = ei_pset1(rhs[i+2]), ptmp3 = ei_pset1(rhs[i+3]); + Packet ptmp0 = ei_pset1(rhs[i]), ptmp1 = ei_pset1(rhs[i+offset1]), + ptmp2 = ei_pset1(rhs[i+2]), ptmp3 = ei_pset1(rhs[i+offset3]); // this helps a lot generating better binary code - const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+1)*lhsStride, - *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+3)*lhsStride; + const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, + *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride; if (PacketSize>1) { @@ -453,17 +456,11 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_colmajor_times_vector( if(peels>1) { Packet A00, A01, A02, A03, A10, A11, A12, A13; - if (alignmentStep==1) - { - A00 = ptmp1; ptmp1 = ptmp3; ptmp3 = A00; - const Scalar* aux = lhs1; - lhs1 = lhs3; lhs3 = aux; - } A01 = ei_pload(&lhs1[alignedStart-1]); A02 = ei_pload(&lhs2[alignedStart-2]); A03 = ei_pload(&lhs3[alignedStart-3]); - + for (int j = alignedStart; j<peeledSize; j+=peels*PacketSize) { A11 = ei_pload(&lhs1[j-1+PacketSize]); ei_palign<1>(A01,A11); @@ -613,6 +610,9 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_rowmajor_times_vector( ei_internal_assert((alignmentPattern==NoneAligned) || PacketSize==1 || (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(Packet))==0); } + + int offset1 = (FirstAligned && alignmentStep==1?3:1); + int offset3 = (FirstAligned && alignmentStep==1?1:3); int rowBound = ((res.size()-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows; for (int i=skipRows; i<rowBound; i+=rowsAtOnce) @@ -620,8 +620,8 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_rowmajor_times_vector( Scalar tmp0 = Scalar(0), tmp1 = Scalar(0), tmp2 = Scalar(0), tmp3 = Scalar(0); // this helps the compiler generating good binary code - const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+1)*lhsStride, - *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+3)*lhsStride; + const Scalar *lhs0 = lhs + i*lhsStride, *lhs1 = lhs + (i+offset1)*lhsStride, + *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride; if (PacketSize>1) { @@ -658,13 +658,6 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_rowmajor_times_vector( * than basic unaligned loads. */ Packet A01, A02, A03, b, A11, A12, A13; - if (alignmentStep==1) - { - // flip row #1 and #3 - b = ptmp1; ptmp1 = ptmp3; ptmp3 = b; - const Scalar* aux = lhs1; - lhs1 = lhs3; lhs3 = aux; - } A01 = ei_pload(&lhs1[alignedStart-1]); A02 = ei_pload(&lhs2[alignedStart-2]); A03 = ei_pload(&lhs3[alignedStart-3]); @@ -690,13 +683,6 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_rowmajor_times_vector( ptmp2 = ei_pmadd(b, A12, ptmp2); ptmp3 = ei_pmadd(b, A13, ptmp3); } - if (alignmentStep==1) - { - // restore rows #1 and #3 - b = ptmp1; ptmp1 = ptmp3; ptmp3 = b; - const Scalar* aux = lhs1; - lhs1 = lhs3; lhs3 = aux; - } } for (int j = peeledSize; j<alignedSize; j+=PacketSize) _EIGEN_ACCUMULATE_PACKETS(,u,u,); @@ -720,7 +706,7 @@ EIGEN_DONT_INLINE static void ei_cache_friendly_product_rowmajor_times_vector( Scalar b = rhs[j]; tmp0 += b*lhs0[j]; tmp1 += b*lhs1[j]; tmp2 += b*lhs2[j]; tmp3 += b*lhs3[j]; } - res[i] += tmp0; res[i+1] += tmp1; res[i+2] += tmp2; res[i+3] += tmp3; + res[i] += tmp0; res[i+offset1] += tmp1; res[i+2] += tmp2; res[i+offset3] += tmp3; } // process remaining first and last rows (at most columnsAtOnce-1) |