-rw-r--r--  Eigen/src/Core/CwiseUnaryOp.h                  |   3
-rw-r--r--  Eigen/src/Core/Product.h                       |  49
-rw-r--r--  Eigen/src/Core/products/GeneralMatrixMatrix.h  | 135
-rw-r--r--  test/product_large.cpp                         |  19
4 files changed, 116 insertions(+), 90 deletions(-)
diff --git a/Eigen/src/Core/CwiseUnaryOp.h b/Eigen/src/Core/CwiseUnaryOp.h
index 0095a1572..3ffb24833 100644
--- a/Eigen/src/Core/CwiseUnaryOp.h
+++ b/Eigen/src/Core/CwiseUnaryOp.h
@@ -96,7 +96,8 @@ class CwiseUnaryOp : ei_no_assignment_operator,
const UnaryOp& _functor() const { return m_functor; }
/** \internal used for introspection */
- const typename MatrixType::Nested& _expression() const { return m_matrix; }
+ const typename ei_cleantype<typename MatrixType::Nested>::type&
+ _expression() const { return m_matrix; }
protected:
const typename MatrixType::Nested m_matrix;
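
The return type of _expression() is cleaned because MatrixType::Nested may itself be a reference-to-const type, while the product-factor traits in Product.h below need a plain reference to the bare nested expression. A minimal sketch of a cleantype-style trait, assuming the usual strip-const/strip-reference behaviour (an illustration only, not Eigen's actual ei_cleantype):

// Illustrative only: recursively remove const and reference qualifiers.
template<typename T> struct cleantype          { typedef T type; };
template<typename T> struct cleantype<const T> { typedef typename cleantype<T>::type type; };
template<typename T> struct cleantype<T&>      { typedef typename cleantype<T>::type type; };

// With Nested = const SomeMatrixType&, cleantype<Nested>::type is SomeMatrixType,
// so the declaration above boils down to a plain "const SomeMatrixType&".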
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 6849d90e3..a645ab6de 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -65,12 +65,11 @@ struct ProductReturnType
template<typename Lhs, typename Rhs>
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
{
- typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
-
- typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
+ typedef typename ei_nested<Lhs,1>::type LhsNested;
+ typedef typename ei_nested<Rhs,1,
typename ei_plain_matrix_type_column_major<Rhs>::type
>::type RhsNested;
-
+
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
};
@@ -95,14 +94,14 @@ template<typename Lhs, typename Rhs> struct ei_product_mode
template<typename XprType> struct ei_product_factor_traits
{
typedef typename ei_traits<XprType>::Scalar Scalar;
- typedef XprType RealXprType;
+ typedef XprType ActualXprType;
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = false,
HasScalarMultiple = false,
Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
};
- static inline const RealXprType& extract(const XprType& x) { return x; }
+ static inline const ActualXprType& extract(const XprType& x) { return x; }
static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); }
};
@@ -112,13 +111,13 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
{
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
- typedef typename Base::RealXprType RealXprType;
+ typedef typename Base::ActualXprType ActualXprType;
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = IsComplex
};
- static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
+ static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); }
};
@@ -128,12 +127,12 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
{
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
- typedef typename Base::RealXprType RealXprType;
+ typedef typename Base::ActualXprType ActualXprType;
enum {
HasScalarMultiple = true
};
- static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
- static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; }
+ static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
+ static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; }
};
/** \class Product
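
Taken together, the ei_product_factor_traits specializations above peel a product operand down to the matrix the kernel can actually read: the generic case reports the operand itself with factor 1, the conjugate wrapper strips itself and raises NeedToConjugate, and the scalar-multiple wrapper strips itself and exposes the stored factor (m_other). A self-contained sketch of the same pattern with hypothetical toy types (Matrix, ConjWrap, ScaleWrap are stand-ins, not Eigen types):

struct Matrix    { /* the bare, directly accessible operand */ };
struct ConjWrap  { Matrix nested; };                 // plays the role of CwiseUnaryOp<ei_scalar_conjugate_op, Matrix>
struct ScaleWrap { Matrix nested; double factor; };  // plays the role of CwiseUnaryOp<ei_scalar_multiple_op, Matrix>

// Generic case: already the actual matrix; factor 1, no conjugation.
template<typename X> struct product_factor_traits {
  typedef X ActualXprType;
  enum { NeedToConjugate = false };
  static const ActualXprType& extract(const X& x)  { return x; }
  static double extractScalarFactor(const X&)      { return 1.0; }
};

// Conjugate wrapper: strip it and tell the kernel to conjugate on the fly.
template<> struct product_factor_traits<ConjWrap> {
  typedef Matrix ActualXprType;
  enum { NeedToConjugate = true };
  static const ActualXprType& extract(const ConjWrap& x) { return x.nested; }
  static double extractScalarFactor(const ConjWrap&)     { return 1.0; }
};

// Scalar-multiple wrapper: strip it and report the factor so it can be folded
// into alpha instead of forcing evaluation of a scaled temporary.
template<> struct product_factor_traits<ScaleWrap> {
  typedef Matrix ActualXprType;
  enum { NeedToConjugate = false };
  static const ActualXprType& extract(const ScaleWrap& x) { return x.nested; }
  static double extractScalarFactor(const ScaleWrap& x)   { return x.factor; }
};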
@@ -819,18 +818,34 @@ template<typename Lhs, typename Rhs, int ProductMode>
template<typename DestDerived>
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const
{
- typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
+ typedef ei_product_factor_traits<_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<_RhsNested> RhsProductTraits;
+
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
+
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs)
+ * RhsProductTraits::extractSalarFactor(m_rhs);
+
+ typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
- typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
+ typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
- LhsCopy lhs(m_lhs);
- RhsCopy rhs(m_rhs);
- ei_cache_friendly_product<Scalar,false,false>(
+ LhsCopy lhs(actualLhs);
+ RhsCopy rhs(actualRhs);
+ ei_cache_friendly_product<Scalar,
+// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>
+ ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)),
+ ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))>
+ (
rows(), cols(), lhs.cols(),
_LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(),
- alpha
+ actualAlpha
);
}
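
Net effect of the new _cacheFriendlyEvalAndAdd: each operand is replaced by its bare matrix, its scalar factor is folded into the coefficient passed to the kernel, and its conjugation is forwarded as a template flag (the two flags are swapped for a row-major destination, see the kernel's new warning below). A hypothetical example of the folding, where the factor 2 and the conjugate() call are illustrative inputs, not anything from this patch:

// res += alpha * (2*A) * B.conjugate();
//
// is evaluated by a single ei_cache_friendly_product call on the bare A and B with
//
//   actualAlpha  = alpha * 2;   // lhs scalar factor folded in, rhs factor is 1
//   ConjugateRhs = true;        // handed over via the template parameters
//
// so neither 2*A nor B.conjugate() is materialized into a temporary.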
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 4630e5040..db63eadf9 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -58,6 +58,9 @@ template<> struct ei_conj_pmadd<true,true>
#ifndef EIGEN_EXTERN_INSTANTIATIONS
+/** \warning you should never call this function directly,
+ * because the ConjugateLhs/ConjugateRhs flags have to
+ * be flipped if resRowMajor==true */
template<typename Scalar, bool ConjugateLhs, bool ConjugateRhs>
static void ei_cache_friendly_product(
int _rows, int _cols, int depth,
@@ -76,6 +79,12 @@ static void ei_cache_friendly_product(
if (resRowMajor)
{
+// return ei_cache_friendly_product<Scalar,ConjugateRhs,ConjugateLhs>(_cols,_rows,depth,
+// !_rhsRowMajor, _rhs, _rhsStride,
+// !_lhsRowMajor, _lhs, _lhsStride,
+// false, res, resStride,
+// alpha);
+
lhs = _rhs;
rhs = _lhs;
lhsStride = _rhsStride;
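
The flag flip demanded by the warning above follows from transposition: a row-major result is produced by computing the transposed product with the operands swapped (exactly the lhs/rhs swap done in this branch), and conjugation stays attached to its operand. Writing conjL/conjR for the per-operand conjugations:

  C   = conjL(A) * conjR(B)           // what the caller requests
  C^T = conjR(B)^T * conjL(A)^T       // what this branch actually computes

After the swap, the operand sitting in the lhs slot carries the caller's rhs conjugation flag and vice versa, which is why Product.h above exchanges the two template parameters before instantiating this kernel.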
@@ -252,59 +261,59 @@ static void ei_cache_friendly_product(
A1 = ei_pload(&blA[1*PacketSize]);
B0 = ei_pload(&blB[0*PacketSize]);
B1 = ei_pload(&blB[1*PacketSize]);
- C0 = cj_pmadd(B0, A0, C0);
+ C0 = cj_pmadd(A0, B0, C0);
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
- C4 = cj_pmadd(B0, A1, C4);
+ C4 = cj_pmadd(A1, B0, C4);
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
B0 = ei_pload(&blB[(nr==4 ? 4 : 2)*PacketSize]);
- C1 = cj_pmadd(B1, A0, C1);
- C5 = cj_pmadd(B1, A1, C5);
+ C1 = cj_pmadd(A0, B1, C1);
+ C5 = cj_pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 5 : 3)*PacketSize]);
- if(nr==4) C2 = cj_pmadd(B2, A0, C2);
- if(nr==4) C6 = cj_pmadd(B2, A1, C6);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C6 = cj_pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[6*PacketSize]);
- if(nr==4) C3 = cj_pmadd(B3, A0, C3);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
A0 = ei_pload(&blA[2*PacketSize]);
- if(nr==4) C7 = cj_pmadd(B3, A1, C7);
+ if(nr==4) C7 = cj_pmadd(A1, B3, C7);
A1 = ei_pload(&blA[3*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[7*PacketSize]);
- C0 = cj_pmadd(B0, A0, C0);
- C4 = cj_pmadd(B0, A1, C4);
+ C0 = cj_pmadd(A0, B0, C0);
+ C4 = cj_pmadd(A1, B0, C4);
B0 = ei_pload(&blB[(nr==4 ? 8 : 4)*PacketSize]);
- C1 = cj_pmadd(B1, A0, C1);
- C5 = cj_pmadd(B1, A1, C5);
+ C1 = cj_pmadd(A0, B1, C1);
+ C5 = cj_pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 9 : 5)*PacketSize]);
- if(nr==4) C2 = cj_pmadd(B2, A0, C2);
- if(nr==4) C6 = cj_pmadd(B2, A1, C6);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C6 = cj_pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[10*PacketSize]);
- if(nr==4) C3 = cj_pmadd(B3, A0, C3);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
A0 = ei_pload(&blA[4*PacketSize]);
- if(nr==4) C7 = cj_pmadd(B3, A1, C7);
+ if(nr==4) C7 = cj_pmadd(A1, B3, C7);
A1 = ei_pload(&blA[5*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[11*PacketSize]);
- C0 = cj_pmadd(B0, A0, C0);
- C4 = cj_pmadd(B0, A1, C4);
+ C0 = cj_pmadd(A0, B0, C0);
+ C4 = cj_pmadd(A1, B0, C4);
B0 = ei_pload(&blB[(nr==4 ? 12 : 6)*PacketSize]);
- C1 = cj_pmadd(B1, A0, C1);
- C5 = cj_pmadd(B1, A1, C5);
+ C1 = cj_pmadd(A0, B1, C1);
+ C5 = cj_pmadd(A1, B1, C5);
B1 = ei_pload(&blB[(nr==4 ? 13 : 7)*PacketSize]);
- if(nr==4) C2 = cj_pmadd(B2, A0, C2);
- if(nr==4) C6 = cj_pmadd(B2, A1, C6);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C6 = cj_pmadd(A1, B2, C6);
if(nr==4) B2 = ei_pload(&blB[14*PacketSize]);
- if(nr==4) C3 = cj_pmadd(B3, A0, C3);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
A0 = ei_pload(&blA[6*PacketSize]);
- if(nr==4) C7 = cj_pmadd(B3, A1, C7);
+ if(nr==4) C7 = cj_pmadd(A1, B3, C7);
A1 = ei_pload(&blA[7*PacketSize]);
if(nr==4) B3 = ei_pload(&blB[15*PacketSize]);
- C0 = cj_pmadd(B0, A0, C0);
- C4 = cj_pmadd(B0, A1, C4);
- C1 = cj_pmadd(B1, A0, C1);
- C5 = cj_pmadd(B1, A1, C5);
- if(nr==4) C2 = cj_pmadd(B2, A0, C2);
- if(nr==4) C6 = cj_pmadd(B2, A1, C6);
- if(nr==4) C3 = cj_pmadd(B3, A0, C3);
- if(nr==4) C7 = cj_pmadd(B3, A1, C7);
+ C0 = cj_pmadd(A0, B0, C0);
+ C4 = cj_pmadd(A1, B0, C4);
+ C1 = cj_pmadd(A0, B1, C1);
+ C5 = cj_pmadd(A1, B1, C5);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C6 = cj_pmadd(A1, B2, C6);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
+ if(nr==4) C7 = cj_pmadd(A1, B3, C7);
blB += 4*nr*PacketSize;
blA += 4*mr;
@@ -318,16 +327,16 @@ static void ei_cache_friendly_product(
A1 = ei_pload(&blA[1*PacketSize]);
B0 = ei_pload(&blB[0*PacketSize]);
B1 = ei_pload(&blB[1*PacketSize]);
- C0 = cj_pmadd(B0, A0, C0);
+ C0 = cj_pmadd(A0, B0, C0);
if(nr==4) B2 = ei_pload(&blB[2*PacketSize]);
- C4 = cj_pmadd(B0, A1, C4);
+ C4 = cj_pmadd(A1, B0, C4);
if(nr==4) B3 = ei_pload(&blB[3*PacketSize]);
- C1 = cj_pmadd(B1, A0, C1);
- C5 = cj_pmadd(B1, A1, C5);
- if(nr==4) C2 = cj_pmadd(B2, A0, C2);
- if(nr==4) C6 = cj_pmadd(B2, A1, C6);
- if(nr==4) C3 = cj_pmadd(B3, A0, C3);
- if(nr==4) C7 = cj_pmadd(B3, A1, C7);
+ C1 = cj_pmadd(A0, B1, C1);
+ C5 = cj_pmadd(A1, B1, C5);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C6 = cj_pmadd(A1, B2, C6);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
+ if(nr==4) C7 = cj_pmadd(A1, B3, C7);
blB += nr*PacketSize;
blA += mr;
@@ -359,12 +368,12 @@ static void ei_cache_friendly_product(
A0 = blA[k];
B0 = blB[0*PacketSize];
B1 = blB[1*PacketSize];
- C0 += B0 * A0;
+ C0 = cj_pmadd(A0, B0, C0);
if(nr==4) B2 = blB[2*PacketSize];
if(nr==4) B3 = blB[3*PacketSize];
- C1 += B1 * A0;
- if(nr==4) C2 += B2 * A0;
- if(nr==4) C3 += B3 * A0;
+ C1 = cj_pmadd(A0, B1, C1);
+ if(nr==4) C2 = cj_pmadd(A0, B2, C2);
+ if(nr==4) C3 = cj_pmadd(A0, B3, C3);
blB += nr*PacketSize;
}
@@ -382,10 +391,10 @@ static void ei_cache_friendly_product(
Scalar c0 = Scalar(0);
if (lhsRowMajor)
for(int k=0; k<actual_kc; k++)
- c0 += lhs[(k2+k)+(i2+i)*lhsStride] * rhs[j2*rhsStride + k2 + k];
+ c0 = cj_pmadd(lhs[(k2+k)+(i2+i)*lhsStride], rhs[j2*rhsStride + k2 + k], c0);
else
for(int k=0; k<actual_kc; k++)
- c0 += lhs[(k2+k)*lhsStride + i2+i] * rhs[j2*rhsStride + k2 + k];
+ c0 = cj_pmadd(lhs[(k2+k)*lhsStride + i2+i], rhs[j2*rhsStride + k2 + k], c0);
res[(j2)*resStride + i2+i] += alpha * c0;
}
}
@@ -395,6 +404,8 @@ static void ei_cache_friendly_product(
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
ei_aligned_stack_delete(Scalar, blockB, kc*cols*PacketSize);
+
+
#else // alternate product from cylmor
enum {
@@ -482,39 +493,39 @@ static void ei_cache_friendly_product(
L0 = ei_pload(&lb[1*PacketSize]);
R1 = ei_pload(&lb[2*PacketSize]);
L1 = ei_pload(&lb[3*PacketSize]);
- T0 = cj_pmadd(R0, A0, T0);
- T1 = cj_pmadd(L0, A0, T1);
+ T0 = cj_pmadd(A0, R0, T0);
+ T1 = cj_pmadd(A0, L0, T1);
R0 = ei_pload(&lb[4*PacketSize]);
L0 = ei_pload(&lb[5*PacketSize]);
- T0 = cj_pmadd(R1, A1, T0);
- T1 = cj_pmadd(L1, A1, T1);
+ T0 = cj_pmadd(A1, R1, T0);
+ T1 = cj_pmadd(A1, L1, T1);
R1 = ei_pload(&lb[6*PacketSize]);
L1 = ei_pload(&lb[7*PacketSize]);
- T0 = cj_pmadd(R0, A2, T0);
- T1 = cj_pmadd(L0, A2, T1);
+ T0 = cj_pmadd(A2, R0, T0);
+ T1 = cj_pmadd(A2, L0, T1);
if(MaxBlockRows==8)
{
R0 = ei_pload(&lb[8*PacketSize]);
L0 = ei_pload(&lb[9*PacketSize]);
}
- T0 = cj_pmadd(R1, A3, T0);
- T1 = cj_pmadd(L1, A3, T1);
+ T0 = cj_pmadd(A3, R1, T0);
+ T1 = cj_pmadd(A3, L1, T1);
if(MaxBlockRows==8)
{
R1 = ei_pload(&lb[10*PacketSize]);
L1 = ei_pload(&lb[11*PacketSize]);
- T0 = cj_pmadd(R0, A4, T0);
- T1 = cj_pmadd(L0, A4, T1);
+ T0 = cj_pmadd(A4, R0, T0);
+ T1 = cj_pmadd(A4, L0, T1);
R0 = ei_pload(&lb[12*PacketSize]);
L0 = ei_pload(&lb[13*PacketSize]);
- T0 = cj_pmadd(R1, A5, T0);
- T1 = cj_pmadd(L1, A5, T1);
+ T0 = cj_pmadd(A5, R1, T0);
+ T1 = cj_pmadd(A5, L1, T1);
R1 = ei_pload(&lb[14*PacketSize]);
L1 = ei_pload(&lb[15*PacketSize]);
- T0 = cj_pmadd(R0, A6, T0);
- T1 = cj_pmadd(L0, A6, T1);
- T0 = cj_pmadd(R1, A7, T0);
- T1 = cj_pmadd(L1, A7, T1);
+ T0 = cj_pmadd(A6, R0, T0);
+ T1 = cj_pmadd(A6, L0, T1);
+ T0 = cj_pmadd(A7, R1, T0);
+ T1 = cj_pmadd(A7, L1, T1);
}
lb += MaxBlockRows*2*PacketSize;
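
All the cj_pmadd call sites above are reordered so that the lhs packet (or scalar) is passed first and the rhs second, matching the ConjugateLhs/ConjugateRhs template flags of ei_conj_pmadd; with the old (rhs, lhs, acc) order, a ConjugateLhs request would effectively have conjugated the rhs operand, which appears to be what the reordering fixes. A minimal scalar sketch of such a conjugation-aware multiply-add (an illustration, not the packetized ei_conj_pmadd):

#include <complex>

// c += (optionally conjugated) a * (optionally conjugated) b
template<bool ConjLhs, bool ConjRhs>
inline std::complex<float> conj_pmadd(std::complex<float> a,   // lhs entry
                                      std::complex<float> b,   // rhs entry
                                      std::complex<float> c)   // accumulator
{
  std::complex<float> x = ConjLhs ? std::conj(a) : a;
  std::complex<float> y = ConjRhs ? std::conj(b) : b;
  return c + x * y;
}

Because the flags are bound to argument positions, keeping the (lhs, rhs) order consistent at every call site is what makes ConjugateLhs/ConjugateRhs apply to the intended operands.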
diff --git a/test/product_large.cpp b/test/product_large.cpp
index c327f70b3..77ae7b587 100644
--- a/test/product_large.cpp
+++ b/test/product_large.cpp
@@ -28,19 +28,18 @@ void test_product_large()
{
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST( product(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
- //CALL_SUBTEST( product(MatrixXf(ei_random<int>(1,320), ei_random<int>(1,320))) );
-// CALL_SUBTEST( product(MatrixXd(ei_random<int>(1,320), ei_random<int>(1,320))) );
-// CALL_SUBTEST( product(MatrixXi(ei_random<int>(1,320), ei_random<int>(1,320))) );
-// CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
-// CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
+ CALL_SUBTEST( product(MatrixXd(ei_random<int>(1,320), ei_random<int>(1,320))) );
+ CALL_SUBTEST( product(MatrixXi(ei_random<int>(1,320), ei_random<int>(1,320))) );
+ CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
+ CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
}
{
// test a specific issue in DiagonalProduct
-// int N = 1000000;
-// VectorXf v = VectorXf::Ones(N);
-// MatrixXf m = MatrixXf::Ones(N,3);
-// m = (v+v).asDiagonal() * m;
-// VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
+ int N = 1000000;
+ VectorXf v = VectorXf::Ones(N);
+ MatrixXf m = MatrixXf::Ones(N,3);
+ m = (v+v).asDiagonal() * m;
+ VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
}
}