diff options
Diffstat (limited to 'blas/level2_impl.h')
-rw-r--r-- | blas/level2_impl.h | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h index 233c7b753..e604fe611 100644 --- a/blas/level2_impl.h +++ b/blas/level2_impl.h @@ -9,6 +9,20 @@ #include "common.h" +template<typename Index, typename Scalar, int StorageOrder, bool ConjugateLhs, bool ConjugateRhs> +struct general_matrix_vector_product_wrapper +{ + static void run(Index rows, Index cols,const Scalar *lhs, Index lhsStride, const Scalar *rhs, Index rhsIncr, Scalar* res, Index resIncr, Scalar alpha) + { + typedef internal::const_blas_data_mapper<Scalar,Index,StorageOrder> LhsMapper; + typedef internal::const_blas_data_mapper<Scalar,Index,RowMajor> RhsMapper; + + internal::general_matrix_vector_product + <Index,Scalar,LhsMapper,StorageOrder,ConjugateLhs,Scalar,RhsMapper,ConjugateRhs>::run( + rows, cols, LhsMapper(lhs, lhsStride), RhsMapper(rhs, rhsIncr), res, resIncr, alpha); + } +}; + int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc) { typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar); @@ -20,9 +34,9 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca for(int k=0; k<4; ++k) func[k] = 0; - func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run); - func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run); - func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run); + func[NOTR] = (general_matrix_vector_product_wrapper<int,Scalar,ColMajor,false,false>::run); + func[TR ] = (general_matrix_vector_product_wrapper<int,Scalar,RowMajor,false,false>::run); + func[ADJ ] = (general_matrix_vector_product_wrapper<int,Scalar,RowMajor,Conj ,false>::run); init = true; } |