aboutsummaryrefslogtreecommitdiffhomepage
path: root/blas
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-02-12 21:48:41 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-02-12 21:48:41 +0100
commit0918c51e600bed36a53448fa276b01387119a3c2 (patch)
tree8183416a03dc22d1cc37b886e0e8f0dd0afe4e85 /blas
parent409547a0c83604b6dea70b8523674ac19e2af958 (diff)
parent4470c9997559522e9b81810948d9783b58444ae4 (diff)
merge Tensor module within Eigen/unsupported and update gemv BLAS wrapper
Diffstat (limited to 'blas')
-rw-r--r--blas/level2_impl.h20
-rw-r--r--blas/level3_impl.h12
2 files changed, 23 insertions, 9 deletions
diff --git a/blas/level2_impl.h b/blas/level2_impl.h
index 233c7b753..e604fe611 100644
--- a/blas/level2_impl.h
+++ b/blas/level2_impl.h
@@ -9,6 +9,20 @@
#include "common.h"
+template<typename Index, typename Scalar, int StorageOrder, bool ConjugateLhs, bool ConjugateRhs>
+struct general_matrix_vector_product_wrapper
+{
+ static void run(Index rows, Index cols,const Scalar *lhs, Index lhsStride, const Scalar *rhs, Index rhsIncr, Scalar* res, Index resIncr, Scalar alpha)
+ {
+ typedef internal::const_blas_data_mapper<Scalar,Index,StorageOrder> LhsMapper;
+ typedef internal::const_blas_data_mapper<Scalar,Index,RowMajor> RhsMapper;
+
+ internal::general_matrix_vector_product
+ <Index,Scalar,LhsMapper,StorageOrder,ConjugateLhs,Scalar,RhsMapper,ConjugateRhs>::run(
+ rows, cols, LhsMapper(lhs, lhsStride), RhsMapper(rhs, rhsIncr), res, resIncr, alpha);
+ }
+};
+
int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *incb, RealScalar *pbeta, RealScalar *pc, int *incc)
{
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int , Scalar *, int, Scalar);
@@ -20,9 +34,9 @@ int EIGEN_BLAS_FUNC(gemv)(char *opa, int *m, int *n, RealScalar *palpha, RealSca
for(int k=0; k<4; ++k)
func[k] = 0;
- func[NOTR] = (internal::general_matrix_vector_product<int,Scalar,ColMajor,false,Scalar,false>::run);
- func[TR ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,false,Scalar,false>::run);
- func[ADJ ] = (internal::general_matrix_vector_product<int,Scalar,RowMajor,Conj, Scalar,false>::run);
+ func[NOTR] = (general_matrix_vector_product_wrapper<int,Scalar,ColMajor,false,false>::run);
+ func[TR ] = (general_matrix_vector_product_wrapper<int,Scalar,RowMajor,false,false>::run);
+ func[ADJ ] = (general_matrix_vector_product_wrapper<int,Scalar,RowMajor,Conj ,false>::run);
init = true;
}
diff --git a/blas/level3_impl.h b/blas/level3_impl.h
index a05872666..37a803ced 100644
--- a/blas/level3_impl.h
+++ b/blas/level3_impl.h
@@ -56,7 +56,7 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
else matrix(c, *m, *n, *ldc) *= beta;
}
- internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,true);
+ internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);
int code = OP(*opa) | (OP(*opb) << 2);
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
@@ -131,12 +131,12 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m,
if(SIDE(*side)==LEFT)
{
- internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
+ internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
func[code](*m, *n, a, *lda, b, *ldb, blocking);
}
else
{
- internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
+ internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
func[code](*n, *m, a, *lda, b, *ldb, blocking);
}
@@ -222,12 +222,12 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
if(SIDE(*side)==LEFT)
{
- internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
+ internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
}
else
{
- internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
+ internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
}
return 1;
@@ -577,7 +577,7 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal
else if(*n<0) info = 3;
else if(*k<0) info = 4;
else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
- else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
+ else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
else if(*ldc<std::max(1,*n)) info = 12;
if(info)
return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);