diff options
Diffstat (limited to 'bench/btl/libs/C_BLAS/C_BLAS_interface.hh')
-rw-r--r-- | bench/btl/libs/C_BLAS/C_BLAS_interface.hh | 74 |
1 files changed, 48 insertions, 26 deletions
diff --git a/bench/btl/libs/C_BLAS/C_BLAS_interface.hh b/bench/btl/libs/C_BLAS/C_BLAS_interface.hh index d0148a29d..319658c6b 100644 --- a/bench/btl/libs/C_BLAS/C_BLAS_interface.hh +++ b/bench/btl/libs/C_BLAS/C_BLAS_interface.hh @@ -26,33 +26,42 @@ extern "C" { #include "cblas.h" -void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, - const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, - const float *beta, float *c, const int *ldc); +// #ifdef PUREBLAS +#include "blas.h" +// #endif -void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, - const float *a, const int *lda, const float *x, const int *incx, - const float *beta, float *y, const int *incy); - -void sscal_(const int *n, const float *alpha, const float *x, const int *incx); - -void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, - float *y, const int *incy); - -void strsv_(const char *uplo, const char *trans, const char *diag, const int *n, - const float *a, const int *lda, float *x, const int *incx); - -void scopy_(const int *n, const float *x, const int *incx, float *y, const int *incy); +// void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, +// const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, +// const float *beta, float *c, const int *ldc); +// +// void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, +// const float *a, const int *lda, const float *x, const int *incx, +// const float *beta, float *y, const int *incy); +// +// void ssymv_(const char *trans, const char* uplo, +// const int* N, const float* alpha, const float *A, +// const int* lda, const float *X, const int* incX, +// const float* beta, float *Y, const int* incY); +// +// void sscal_(const int *n, const float *alpha, const float *x, const int *incx); +// +// void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, +// float *y, const int *incy); +// +// void strsv_(const char *uplo, const char *trans, const char *diag, const int *n, +// const float *a, const int *lda, float *x, const int *incx); +// +// void scopy_(const int *n, const float *x, const int *incx, float *y, const int *incy); // Cholesky Factorization // #include "mkl_lapack.h" - void spotrf_(const char* uplo, const int* n, float *a, const int* ld, int* info); - void dpotrf_(const char* uplo, const int* n, double *a, const int* ld, int* info); +// void spotrf_(const char* uplo, const int* n, float *a, const int* ld, int* info); +// void dpotrf_(const char* uplo, const int* n, double *a, const int* ld, int* info); void ssytrd_(char *uplo, const int *n, float *a, const int *lda, float *d, float *e, float *tau, float *work, int *lwork, int *info ); void sgehrd_( const int *n, int *ilo, int *ihi, float *a, const int *lda, float *tau, float *work, int *lwork, int *info ); // LU row pivoting - void sgetrf_(const int* m, const int* n, float *a, const int* ld, int* ipivot, int* info); +// void sgetrf_(const int* m, const int* n, float *a, const int* ld, int* ipivot, int* info); // LU full pivoting void sgetc2_(const int* n, float *a, const int *lda, int *ipiv, int *jpiv, int*info ); #ifdef HAS_LAPACK @@ -85,6 +94,11 @@ public : cblas_dgemv(CblasColMajor,CblasTrans,N,N,1.0,A,N,B,1,0.0,X,1); } + static inline void symv(gene_matrix & A, gene_vector & B, gene_vector & X, int N) + { + cblas_dsymv(CblasColMajor,CblasLower,CblasTrans,N,N,1.0,A,N,B,1,0.0,X,1); + } + static inline void matrix_matrix_product(gene_matrix & A, gene_matrix & B, gene_matrix & X, int N){ cblas_dgemm(CblasColMajor,CblasNoTrans,CblasNoTrans,N,N,N,1.0,A,N,B,N,0.0,X,N); } @@ -112,13 +126,13 @@ public : }; -static const float fone = 1; -static const float fzero = 0; -static const char notrans = 'N'; -static const char trans = 'T'; -static const char nonunit = 'N'; -static const char lower = 'L'; -static const int intone = 1; +static float fone = 1; +static float fzero = 0; +static char notrans = 'N'; +static char trans = 'T'; +static char nonunit = 'N'; +static char lower = 'L'; +static blasint intone = 1; template<> class C_BLAS_interface<float> : public f77_interface_base<float> @@ -139,6 +153,14 @@ public : #endif } + static inline void symv(gene_matrix & A, gene_vector & B, gene_vector & X, int N){ + #ifdef PUREBLAS + ssymv_(&lower, &N,&fone,A,&N,B,&intone,&fzero,X,&intone); + #else + cblas_ssymv(CblasColMajor,CblasLower,N,1.0,A,N,B,1,0.0,X,1); + #endif + } + static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N){ #ifdef PUREBLAS sgemv_(&trans,&N,&N,&fone,A,&N,B,&intone,&fzero,X,&intone); |