about summary refs log tree commit diff homepage
path: root/Eigen/src/Core/Product.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r-- Eigen/src/Core/Product.h 49
1 files changed, 32 insertions, 17 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 6849d90e3..a645ab6de 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -65,12 +65,11 @@ struct ProductReturnType
template<typename Lhs, typename Rhs>
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
{
- typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
-
- typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
+ typedef typename ei_nested<Lhs,1>::type LhsNested;
+ typedef typename ei_nested<Rhs,1,
typename ei_plain_matrix_type_column_major<Rhs>::type
>::type RhsNested;
-
+
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
};
@@ -95,14 +94,14 @@ template<typename Lhs, typename Rhs> struct ei_product_mode
template<typename XprType> struct ei_product_factor_traits
{
typedef typename ei_traits<XprType>::Scalar Scalar;
- typedef XprType RealXprType;
+ typedef XprType ActualXprType;
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = false,
HasScalarMultiple = false,
Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
};
- static inline const RealXprType& extract(const XprType& x) { return x; }
+ static inline const ActualXprType& extract(const XprType& x) { return x; }
static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); }
};
@@ -112,13 +111,13 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
{
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_conjugate_op<Scalar>, NestedXpr> XprType;
- typedef typename Base::RealXprType RealXprType;
+ typedef typename Base::ActualXprType ActualXprType;
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = IsComplex
};
- static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
+ static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); }
};
@@ -128,12 +127,12 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
{
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
- typedef typename Base::RealXprType RealXprType;
+ typedef typename Base::ActualXprType ActualXprType;
enum {
HasScalarMultiple = true
};
- static inline const RealXprType& extract(const XprType& x) { return x._expression(); }
- static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().value; }
+ static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
+ static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; }
};
/** \class Product
@@ -819,18 +818,34 @@ template<typename Lhs, typename Rhs, int ProductMode>
template<typename DestDerived>
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const
{
- typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
+ typedef ei_product_factor_traits<_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<_RhsNested> RhsProductTraits;
+
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
+
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs)
+ * RhsProductTraits::extractSalarFactor(m_rhs);
+
+ typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
- typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
+ typedef typename ei_product_copy_rhs<ActualRhsType>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
- LhsCopy lhs(m_lhs);
- RhsCopy rhs(m_rhs);
- ei_cache_friendly_product<Scalar,false,false>(
+ LhsCopy lhs(actualLhs);
+ RhsCopy rhs(actualRhs);
+ ei_cache_friendly_product<Scalar,
+// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>
+ ((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)),
+ ((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))>
+ (
rows(), cols(), lhs.cols(),
_LhsCopy::Flags&RowMajorBit, (const Scalar*)&(lhs.const_cast_derived().coeffRef(0,0)), lhs.stride(),
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride(),
- alpha
+ actualAlpha
);
}