aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/GeneralProduct.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 17:13:28 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-21 17:13:28 +0100
commitd67548f345d01c69d9dbba5869d8cc0159e96464 (patch)
tree07188e26fe6758f5311a857ba5597719a6c57581 /Eigen/src/Core/GeneralProduct.h
parent6c7ab508117d84671054808c921980a4908efb20 (diff)
Get rid of GeneralProduct<> for GemvProduct
Diffstat (limited to 'Eigen/src/Core/GeneralProduct.h')
-rw-r--r--Eigen/src/Core/GeneralProduct.h208
1 files changed, 193 insertions, 15 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 4c0fc7f63..57d5d3c38 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -342,16 +342,19 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
*/
namespace internal {
+#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs>
struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> >
: traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> >
{};
+#endif
template<int Side, int StorageOrder, bool BlasCompatible>
struct gemv_selector;
} // end namespace internal
+#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs>
class GeneralProduct<Lhs, Rhs, GemvProduct>
: public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs>
@@ -378,24 +381,10 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha);
}
};
+#endif
namespace internal {
-// The vector is on the left => transposition
-template<int StorageOrder, bool BlasCompatible>
-struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
-{
- template<typename ProductType, typename Dest>
- static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
- {
- Transpose<Dest> destT(dest);
- enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
- gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
- ::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
- (prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
- }
-};
-
template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
template<typename Scalar,int Size,int MaxSize>
@@ -432,6 +421,23 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
#endif
};
+#ifndef EIGEN_TEST_EVALUATORS
+
+// The vector is on the left => transposition
+template<int StorageOrder, bool BlasCompatible>
+struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
+{
+ template<typename ProductType, typename Dest>
+ static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
+ {
+ Transpose<Dest> destT(dest);
+ enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
+ gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
+ ::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
+ (prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
+ }
+};
+
template<> struct gemv_selector<OnTheRight,ColMajor,true>
{
template<typename ProductType, typename Dest>
@@ -582,6 +588,178 @@ template<> struct gemv_selector<OnTheRight,RowMajor,false>
}
};
+#else // EIGEN_TEST_EVALUATORS
+
+// The vector is on the left => transposition
+template<int StorageOrder, bool BlasCompatible>
+struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
+{
+ template<typename Lhs, typename Rhs, typename Dest>
+ static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
+ {
+ Transpose<Dest> destT(dest);
+ enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
+ gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
+ ::run(rhs.transpose(), lhs.transpose(), destT, alpha);
+ }
+};
+
+template<> struct gemv_selector<OnTheRight,ColMajor,true>
+{
+ template<typename Lhs, typename Rhs, typename Dest>
+ static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
+ {
+ typedef typename Dest::Index Index;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+ typedef typename Dest::Scalar ResScalar;
+ typedef typename Dest::RealScalar RealScalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+
+ typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
+
+ ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
+ ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
+
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
+ * RhsBlasTraits::extractScalarFactor(rhs);
+
+ enum {
+ // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
+ // on, the other hand it is good for the cache to pack the vector anyways...
+ EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
+ ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
+ MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
+ };
+
+ gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
+
+ bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
+ bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
+
+ RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
+
+ ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
+ evalToDest ? dest.data() : static_dest.data());
+
+ if(!evalToDest)
+ {
+ #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ int size = dest.size();
+ EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ #endif
+ if(!alphaIsCompatible)
+ {
+ MappedDest(actualDestPtr, dest.size()).setZero();
+ compatibleAlpha = RhsScalar(1);
+ }
+ else
+ MappedDest(actualDestPtr, dest.size()) = dest;
+ }
+
+ general_matrix_vector_product
+ <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+ actualLhs.rows(), actualLhs.cols(),
+ actualLhs.data(), actualLhs.outerStride(),
+ actualRhs.data(), actualRhs.innerStride(),
+ actualDestPtr, 1,
+ compatibleAlpha);
+
+ if (!evalToDest)
+ {
+ if(!alphaIsCompatible)
+ dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
+ else
+ dest = MappedDest(actualDestPtr, dest.size());
+ }
+ }
+};
+
+template<> struct gemv_selector<OnTheRight,RowMajor,true>
+{
+ template<typename Lhs, typename Rhs, typename Dest>
+ static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
+ {
+ typedef typename Dest::Index Index;
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+ typedef typename Dest::Scalar ResScalar;
+ typedef typename Dest::RealScalar RealScalar;
+
+ typedef internal::blas_traits<Lhs> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef internal::blas_traits<Rhs> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
+
+ typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
+ typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
+
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
+ * RhsBlasTraits::extractScalarFactor(rhs);
+
+ enum {
+ // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
+ // on, the other hand it is good for the cache to pack the vector anyways...
+ DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
+ };
+
+ gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
+
+ ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
+ DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
+
+ if(!DirectlyUseRhs)
+ {
+ #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ int size = actualRhs.size();
+ EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ #endif
+ Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
+ }
+
+ general_matrix_vector_product
+ <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+ actualLhs.rows(), actualLhs.cols(),
+ actualLhs.data(), actualLhs.outerStride(),
+ actualRhsPtr, 1,
+ dest.data(), dest.innerStride(),
+ actualAlpha);
+ }
+};
+
+template<> struct gemv_selector<OnTheRight,ColMajor,false>
+{
+ template<typename Lhs, typename Rhs, typename Dest>
+ static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
+ {
+ typedef typename Dest::Index Index;
+ // TODO makes sure dest is sequentially stored in memory, otherwise use a temp
+ const Index size = rhs.rows();
+ for(Index k=0; k<size; ++k)
+ dest += (alpha*rhs.coeff(k)) * lhs.col(k);
+ }
+};
+
+template<> struct gemv_selector<OnTheRight,RowMajor,false>
+{
+ template<typename Lhs, typename Rhs, typename Dest>
+ static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
+ {
+ typedef typename Dest::Index Index;
+ // TODO makes sure rhs is sequentially stored in memory, otherwise use a temp
+ const Index rows = dest.rows();
+ for(Index i=0; i<rows; ++i)
+ dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(rhs.transpose())).sum();
+ }
+};
+
+#endif // EIGEN_TEST_EVALUATORS
+
} // end namespace internal
/***************************************************************************