about | summary | refs | log | tree | commit | diff | homepage
path: root/Eigen/src/Core/Product.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r-- Eigen/src/Core/Product.h 185
1 files changed, 127 insertions, 58 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index a645ab6de..d63a7aa95 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -73,24 +73,9 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
};
-/* Helper class to determine the type of the product, can be either:
- * - NormalProduct
- * - CacheFriendlyProduct
- */
-template<typename Lhs, typename Rhs> struct ei_product_mode
-{
- enum{
-
- value = Lhs::MaxColsAtCompileTime == Dynamic
- && ( Lhs::MaxRowsAtCompileTime == Dynamic
- || Rhs::MaxColsAtCompileTime == Dynamic )
- && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(Lhs::Flags&DirectAccessBit))))
- && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(Rhs::Flags&DirectAccessBit))))
- && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret)
- ? CacheFriendlyProduct
- : NormalProduct };
-};
-
+/* Helper class to analyze the factors of a Product expression.
+ * In particular, it strips unary minus (operator-), scalar multiples,
+ * and conjugation from a factor to expose the underlying expression. */
template<typename XprType> struct ei_product_factor_traits
{
typedef typename ei_traits<XprType>::Scalar Scalar;
@@ -98,11 +83,10 @@ template<typename XprType> struct ei_product_factor_traits
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = false,
- HasScalarMultiple = false,
- Access = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
+ ActualAccess = int(ei_traits<XprType>::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess
};
static inline const ActualXprType& extract(const XprType& x) { return x; }
- static inline Scalar extractSalarFactor(const XprType&) { return Scalar(1); }
+ static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
};
// pop conjugate
@@ -117,8 +101,8 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
IsComplex = NumTraits<Scalar>::IsComplex,
NeedToConjugate = IsComplex
};
- static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
- static inline Scalar extractSalarFactor(const XprType& x) { return Base::extractSalarFactor(x._expression()); }
+ static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+ static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); }
};
// pop scalar multiple
@@ -128,11 +112,41 @@ template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<Cw
typedef ei_product_factor_traits<NestedXpr> Base;
typedef CwiseUnaryOp<ei_scalar_multiple_op<Scalar>, NestedXpr> XprType;
typedef typename Base::ActualXprType ActualXprType;
- enum {
- HasScalarMultiple = true
- };
- static inline const ActualXprType& extract(const XprType& x) { return x._expression(); }
- static inline Scalar extractSalarFactor(const XprType& x) { return x._functor().m_other; }
+ static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+ static inline Scalar extractScalarFactor(const XprType& x)
+ { return x._functor().m_other * Base::extractScalarFactor(x._expression()); }
+};
+
+// pop opposite (unary minus): the sign is folded into the extracted scalar factor
+template<typename Scalar, typename NestedXpr> struct ei_product_factor_traits<CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> >
+ : ei_product_factor_traits<NestedXpr>
+{
+ typedef ei_product_factor_traits<NestedXpr> Base;
+ typedef CwiseUnaryOp<ei_scalar_opposite_op<Scalar>, NestedXpr> XprType;
+ typedef typename Base::ActualXprType ActualXprType;
+ static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); }
+ static inline Scalar extractScalarFactor(const XprType& x)
+ { return - Base::extractScalarFactor(x._expression()); }
+};
+
+/* Helper class to determine the evaluation mode of the product, which is either:
+ * - NormalProduct
+ * - CacheFriendlyProduct
+ */
+template<typename Lhs, typename Rhs> struct ei_product_mode
+{
+ typedef typename ei_product_factor_traits<Lhs>::ActualXprType ActualLhs;
+ typedef typename ei_product_factor_traits<Rhs>::ActualXprType ActualRhs;
+ enum{
+
+ value = Lhs::MaxColsAtCompileTime == Dynamic
+ && ( Lhs::MaxRowsAtCompileTime == Dynamic
+ || Rhs::MaxColsAtCompileTime == Dynamic )
+ && (!(Rhs::IsVectorAtCompileTime && (Lhs::Flags&RowMajorBit) && (!(ActualLhs::Flags&DirectAccessBit))))
+ && (!(Lhs::IsVectorAtCompileTime && (!(Rhs::Flags&RowMajorBit)) && (!(ActualRhs::Flags&DirectAccessBit))))
+ && (ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret)
+ ? CacheFriendlyProduct
+ : NormalProduct };
};
/** \class Product
@@ -552,11 +566,11 @@ void ei_cache_friendly_product(
bool resRowMajor, Scalar* res, int resStride,
Scalar alpha);
-template<typename Scalar, typename RhsType>
+template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename RhsType>
static void ei_cache_friendly_product_colmajor_times_vector(
int size, const Scalar* lhs, int lhsStride, const RhsType& rhs, Scalar* res, Scalar alpha);
-template<typename Scalar, typename ResType>
+template<bool ConjugateLhs, bool ConjugateRhs, typename Scalar, typename ResType>
static void ei_cache_friendly_product_rowmajor_times_vector(
const Scalar* lhs, int lhsStride, const Scalar* rhs, int rhsSize, ResType& res, Scalar alpha);
@@ -572,10 +586,10 @@ static void ei_cache_friendly_product_rowmajor_times_vector(
template<typename ProductType,
int LhsRows = ei_traits<ProductType>::RowsAtCompileTime,
int LhsOrder = int(ei_traits<ProductType>::LhsFlags)&RowMajorBit ? RowMajor : ColMajor,
- int LhsHasDirectAccess = int(ei_traits<ProductType>::LhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess,
+ int LhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested>::ActualAccess,
int RhsCols = ei_traits<ProductType>::ColsAtCompileTime,
int RhsOrder = int(ei_traits<ProductType>::RhsFlags)&RowMajorBit ? RowMajor : ColMajor,
- int RhsHasDirectAccess = int(ei_traits<ProductType>::RhsFlags)&DirectAccessBit? HasDirectAccess : NoDirectAccess>
+ int RhsHasDirectAccess = ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested>::ActualAccess>
struct ei_cache_friendly_product_selector
{
template<typename DestDerived>
@@ -592,7 +606,6 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,NoDirectA
template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{
- // FIXME is it really used ?
ei_assert(alpha==typename ProductType::Scalar(1));
const int size = product.rhs().rows();
for (int k=0; k<size; ++k)
@@ -606,10 +619,21 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess>
struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirectAccess,1,RhsOrder,RhsAccess>
{
typedef typename ProductType::Scalar Scalar;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
+
template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
+ * RhsProductTraits::extractScalarFactor(product.rhs());
+
enum {
EvalToRes = (ei_packet_traits<Scalar>::size==1)
||((DestDerived::Flags&ActualPacketAccessBit) && (!(DestDerived::Flags & RowMajorBit))) };
@@ -621,9 +645,12 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
_res = ei_aligned_stack_new(Scalar,res.size());
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res;
}
- ei_cache_friendly_product_colmajor_times_vector(res.size(),
- &product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
- product.rhs(), _res, alpha);
+
+ ei_cache_friendly_product_colmajor_times_vector
+ <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>(
+ res.size(),
+ &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(),
+ actualRhs, _res, actualAlpha);
if (!EvalToRes)
{
@@ -653,10 +680,21 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols>
struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,RowMajor,HasDirectAccess>
{
typedef typename ProductType::Scalar Scalar;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
+
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
+ * RhsProductTraits::extractScalarFactor(product.rhs());
+
enum {
EvalToRes = (ei_packet_traits<Scalar>::size==1)
||((DestDerived::Flags & ActualPacketAccessBit) && (DestDerived::Flags & RowMajorBit)) };
@@ -668,9 +706,11 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
_res = ei_aligned_stack_new(Scalar, res.size());
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res;
}
- ei_cache_friendly_product_colmajor_times_vector(res.size(),
- &product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
- product.lhs().transpose(), _res, alpha);
+
+ ei_cache_friendly_product_colmajor_times_vector
+ <RhsProductTraits::NeedToConjugate,LhsProductTraits::NeedToConjugate>(res.size(),
+ &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(),
+ actualLhs.transpose(), _res, actualAlpha);
if (!EvalToRes)
{
@@ -685,24 +725,39 @@ template<typename ProductType, int LhsRows, int RhsOrder, int RhsAccess>
struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirectAccess,1,RhsOrder,RhsAccess>
{
typedef typename ProductType::Scalar Scalar;
- typedef typename ei_traits<ProductType>::_RhsNested Rhs;
+
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
+
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
+
enum {
- UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Rhs::Flags&ActualPacketAccessBit))
- && (!(Rhs::Flags & RowMajorBit)) };
+ UseRhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualRhsType::Flags&ActualPacketAccessBit))
+ && (!(ActualRhsType::Flags & RowMajorBit)) };
template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
+ * RhsProductTraits::extractScalarFactor(product.rhs());
+
Scalar* EIGEN_RESTRICT _rhs;
if (UseRhsDirectly)
- _rhs = &product.rhs().const_cast_derived().coeffRef(0);
+ _rhs = &actualRhs.const_cast_derived().coeffRef(0);
else
{
- _rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
- Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs();
+ _rhs = ei_aligned_stack_new(Scalar, actualRhs.size());
+ Map<Matrix<Scalar,ActualRhsType::SizeAtCompileTime,1> >(_rhs, actualRhs.size()) = actualRhs;
}
- ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
- _rhs, product.rhs().size(), res, alpha);
+
+ ei_cache_friendly_product_rowmajor_times_vector
+ <LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>(
+ &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.stride(),
+ _rhs, product.rhs().size(), res, actualAlpha);
if (!UseRhsDirectly) ei_aligned_stack_delete(Scalar, _rhs, product.rhs().size());
}
@@ -713,24 +768,39 @@ template<typename ProductType, int LhsOrder, int LhsAccess, int RhsCols>
struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCols,ColMajor,HasDirectAccess>
{
typedef typename ProductType::Scalar Scalar;
- typedef typename ei_traits<ProductType>::_LhsNested Lhs;
+
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_LhsNested> LhsProductTraits;
+ typedef ei_product_factor_traits<typename ei_traits<ProductType>::_RhsNested> RhsProductTraits;
+
+ typedef typename LhsProductTraits::ActualXprType ActualLhsType;
+ typedef typename RhsProductTraits::ActualXprType ActualRhsType;
+
enum {
- UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (Lhs::Flags&ActualPacketAccessBit))
- && (Lhs::Flags & RowMajorBit) };
+ UseLhsDirectly = ((ei_packet_traits<Scalar>::size==1) || (ActualLhsType::Flags&ActualPacketAccessBit))
+ && (ActualLhsType::Flags & RowMajorBit) };
template<typename DestDerived>
inline static void run(DestDerived& res, const ProductType& product, typename ProductType::Scalar alpha)
{
+ const ActualLhsType& actualLhs = LhsProductTraits::extract(product.lhs());
+ const ActualRhsType& actualRhs = RhsProductTraits::extract(product.rhs());
+
+ Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(product.lhs())
+ * RhsProductTraits::extractScalarFactor(product.rhs());
+
Scalar* EIGEN_RESTRICT _lhs;
if (UseLhsDirectly)
- _lhs = &product.lhs().const_cast_derived().coeffRef(0);
+ _lhs = &actualLhs.const_cast_derived().coeffRef(0);
else
{
- _lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
- Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs();
+ _lhs = ei_aligned_stack_new(Scalar, actualLhs.size());
+ Map<Matrix<Scalar,ActualLhsType::SizeAtCompileTime,1> >(_lhs, actualLhs.size()) = actualLhs;
}
- ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
- _lhs, product.lhs().size(), res, alpha);
+
+ ei_cache_friendly_product_rowmajor_times_vector
+ <RhsProductTraits::NeedToConjugate, LhsProductTraits::NeedToConjugate>(
+ &actualRhs.const_cast_derived().coeffRef(0,0), actualRhs.stride(),
+ _lhs, product.lhs().size(), res, actualAlpha);
if(!UseLhsDirectly) ei_aligned_stack_delete(Scalar, _lhs, product.lhs().size());
}
@@ -827,8 +897,8 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs);
const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs);
- Scalar actualAlpha = alpha * LhsProductTraits::extractSalarFactor(m_lhs)
- * RhsProductTraits::extractSalarFactor(m_rhs);
+ Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs)
+ * RhsProductTraits::extractScalarFactor(m_rhs);
typedef typename ei_product_copy_lhs<ActualLhsType>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
@@ -837,7 +907,6 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
LhsCopy lhs(actualLhs);
RhsCopy rhs(actualRhs);
ei_cache_friendly_product<Scalar,
-// LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate>
((int(Flags)&RowMajorBit) ? bool(RhsProductTraits::NeedToConjugate) : bool(LhsProductTraits::NeedToConjugate)),
((int(Flags)&RowMajorBit) ? bool(LhsProductTraits::NeedToConjugate) : bool(RhsProductTraits::NeedToConjugate))>
(