aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/GeneralProduct.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-31 16:33:51 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-10-31 16:33:51 -0700
commit2dde63499c4ef836a0d9dfd443494d863ad62b16 (patch)
tree572f24e62fb03318adb527eaefcdb4593cb31b50 /Eigen/src/Core/GeneralProduct.h
parent7f2c6ed2fa35d7f83f0da83c8564b7bd5b01d232 (diff)
Generalized the matrix vector product code.
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r--Eigen/src/Core/GeneralProduct.h32
1 files changed, 18 insertions, 14 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 7179eb124..9d3d5562c 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_GENERAL_PRODUCT_H
#define EIGEN_GENERAL_PRODUCT_H
-namespace Eigen {
+namespace Eigen {
/** \class GeneralProduct
* \ingroup Core_Module
@@ -257,7 +257,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs>
{
template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
-
+
public:
EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
@@ -266,7 +266,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::RealScalar, typename Rhs::RealScalar>::value),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
-
+
struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() = src; } };
struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
@@ -277,12 +277,12 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
dst.const_cast_derived() += m_scale * src;
}
};
-
+
template<typename Dest>
inline void evalTo(Dest& dest) const {
internal::outer_product_selector_run(*this, dest, set(), IsRowMajor<Dest>());
}
-
+
template<typename Dest>
inline void addTo(Dest& dest) const {
internal::outer_product_selector_run(*this, dest, add(), IsRowMajor<Dest>());
@@ -436,12 +436,12 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
-
+
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
evalToDest ? dest.data() : static_dest.data());
-
+
if(!evalToDest)
{
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
@@ -457,11 +457,13 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
MappedDest(actualDestPtr, dest.size()) = dest;
}
+ typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
+ typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
general_matrix_vector_product
- <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+ <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
actualLhs.rows(), actualLhs.cols(),
- actualLhs.data(), actualLhs.outerStride(),
- actualRhs.data(), actualRhs.innerStride(),
+ LhsMapper(actualLhs.data(), actualLhs.outerStride()),
+ RhsMapper(actualRhs.data(), actualRhs.innerStride()),
actualDestPtr, 1,
compatibleAlpha);
@@ -516,11 +518,13 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true>
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
}
+ typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
+ typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
general_matrix_vector_product
- <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+ <Index,LhsScalar,LhsMapper,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
actualLhs.rows(), actualLhs.cols(),
- actualLhs.data(), actualLhs.outerStride(),
- actualRhsPtr, 1,
+ LhsMapper(actualLhs.data(), actualLhs.outerStride()),
+ RhsMapper(actualRhsPtr, 1),
dest.data(), dest.innerStride(),
actualAlpha);
}
@@ -594,7 +598,7 @@ MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
#ifdef EIGEN_DEBUG_PRODUCT
internal::product_type<Derived,OtherDerived>::debug();
#endif
-
+
return Product<Derived, OtherDerived>(derived(), other.derived());
}
#else