From 0f2d480af0d4e498056b9148a8e5a9b37d1fd321 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 7 Jul 2010 16:41:29 +0200 Subject: add support for complex<double> --- bench/bench_gemm.cpp | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'bench/bench_gemm.cpp') diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp index 0da87b583..06a124f8f 100644 --- a/bench/bench_gemm.cpp +++ b/bench/bench_gemm.cpp @@ -10,7 +10,8 @@ using namespace std; using namespace Eigen; #ifndef SCALAR -#define SCALAR std::complex<float> +#define SCALAR std::complex<double> +// #define SCALAR double #endif typedef SCALAR Scalar; @@ -28,6 +29,8 @@ static double done = 1; static double szero = 0; static std::complex<float> cfone = 1; static std::complex<float> cfzero = 0; +static std::complex<double> cdone = 1; +static std::complex<double> cdzero = 0; static char notrans = 'N'; static char trans = 'T'; static char nonunit = 'N'; @@ -57,6 +60,17 @@ void blas_gemm(const MatrixXcf& a, const MatrixXcf& b, MatrixXcf& c) (float*)c.data(),&ldc); } +void blas_gemm(const MatrixXcd& a, const MatrixXcd& b, MatrixXcd& c) +{ + int M = c.rows(); int N = c.cols(); int K = a.cols(); + int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows(); + + zgemm_(&notrans,&notrans,&M,&N,&K,(double*)&cdone, + const_cast<double*>((const double*)a.data()),&lda, + const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone, + (double*)c.data(),&ldc); +} + void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c) { int M = c.rows(); int N = c.cols(); int K = a.cols(); @@ -71,7 +85,7 @@ void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c) #endif template<typename M> -void gemm(const M& a, const M& b, M& c) +EIGEN_DONT_INLINE void gemm(const M& a, const M& b, M& c) { c.noalias() += a * b; } @@ -80,8 +94,10 @@ int main(int argc, char ** argv) { std::ptrdiff_t l1 = ei_queryL1CacheSize(); std::ptrdiff_t l2 = ei_queryTopLevelCacheSize(); - std::cout << "L1 cache size = " << (l1>0 ? 
l1/1024 : -1) << " KB\n"; - std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n"; + std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n"; + std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n"; + typedef ei_product_blocking_traits<Scalar> Blocking; + std::cout << "Register blocking = " << Blocking::mr << " x " << Blocking::nr << "\n"; int rep = 1; // number of repetitions per try int tries = 2; // number of tries, we keep the best -- cgit v1.2.3