diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-07-07 16:41:29 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-07-07 16:41:29 +0200 |
commit | 0f2d480af0d4e498056b9148a8e5a9b37d1fd321 (patch) | |
tree | 7bf297f771011187e89c902d5ab5d4ff23d74794 | |
parent | a2415388ef05154ca5f655a58694ce908e21213a (diff) |
add support for complex
-rw-r--r-- | bench/bench_gemm.cpp | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp index 0da87b583..06a124f8f 100644 --- a/bench/bench_gemm.cpp +++ b/bench/bench_gemm.cpp @@ -10,7 +10,8 @@ using namespace std; using namespace Eigen; #ifndef SCALAR -#define SCALAR std::complex<float> +#define SCALAR std::complex<double> +// #define SCALAR double #endif typedef SCALAR Scalar; @@ -28,6 +29,8 @@ static double done = 1; static double szero = 0; static std::complex<float> cfone = 1; static std::complex<float> cfzero = 0; +static std::complex<double> cdone = 1; +static std::complex<double> cdzero = 0; static char notrans = 'N'; static char trans = 'T'; static char nonunit = 'N'; @@ -57,6 +60,17 @@ void blas_gemm(const MatrixXcf& a, const MatrixXcf& b, MatrixXcf& c) (float*)c.data(),&ldc); } +void blas_gemm(const MatrixXcd& a, const MatrixXcd& b, MatrixXcd& c) +{ + int M = c.rows(); int N = c.cols(); int K = a.cols(); + int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows(); + + zgemm_(¬rans,¬rans,&M,&N,&K,(double*)&cdone, + const_cast<double*>((const double*)a.data()),&lda, + const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone, + (double*)c.data(),&ldc); +} + void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c) { int M = c.rows(); int N = c.cols(); int K = a.cols(); @@ -71,7 +85,7 @@ void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c) #endif template<typename M> -void gemm(const M& a, const M& b, M& c) +EIGEN_DONT_INLINE void gemm(const M& a, const M& b, M& c) { c.noalias() += a * b; } @@ -80,8 +94,10 @@ int main(int argc, char ** argv) { std::ptrdiff_t l1 = ei_queryL1CacheSize(); std::ptrdiff_t l2 = ei_queryTopLevelCacheSize(); - std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n"; - std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n"; + std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n"; + std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n"; + typedef ei_product_blocking_traits<Scalar> Blocking; + std::cout << "Register blocking = " << Blocking::mr << " x " << Blocking::nr << "\n"; int rep = 1; // number of repetitions per try int tries = 2; // number of tries, we keep the best |