about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Gael Guennebaud <g.gael@free.fr> 2010-07-07 16:41:29 +0200
committer Gael Guennebaud <g.gael@free.fr> 2010-07-07 16:41:29 +0200
commit 0f2d480af0d4e498056b9148a8e5a9b37d1fd321 (patch)
tree 7bf297f771011187e89c902d5ab5d4ff23d74794
parent a2415388ef05154ca5f655a58694ce908e21213a (diff)
add support for complex
-rw-r--r-- bench/bench_gemm.cpp 24
1 files changed, 20 insertions, 4 deletions
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp
index 0da87b583..06a124f8f 100644
--- a/bench/bench_gemm.cpp
+++ b/bench/bench_gemm.cpp
@@ -10,7 +10,8 @@ using namespace std;
using namespace Eigen;
#ifndef SCALAR
-#define SCALAR std::complex<float>
+#define SCALAR std::complex<double>
+// #define SCALAR double
#endif
typedef SCALAR Scalar;
@@ -28,6 +29,8 @@ static double done = 1;
static double szero = 0;
static std::complex<float> cfone = 1;
static std::complex<float> cfzero = 0;
+static std::complex<double> cdone = 1;
+static std::complex<double> cdzero = 0;
static char notrans = 'N';
static char trans = 'T';
static char nonunit = 'N';
@@ -57,6 +60,17 @@ void blas_gemm(const MatrixXcf& a, const MatrixXcf& b, MatrixXcf& c)
(float*)c.data(),&ldc);
}
+void blas_gemm(const MatrixXcd& a, const MatrixXcd& b, MatrixXcd& c)
+{
+ int M = c.rows(); int N = c.cols(); int K = a.cols();
+ int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows();
+
+ zgemm_(&notrans,&notrans,&M,&N,&K,(double*)&cdone,
+ const_cast<double*>((const double*)a.data()),&lda,
+ const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone,
+ (double*)c.data(),&ldc);
+}
+
void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
{
int M = c.rows(); int N = c.cols(); int K = a.cols();
@@ -71,7 +85,7 @@ void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
#endif
template<typename M>
-void gemm(const M& a, const M& b, M& c)
+EIGEN_DONT_INLINE void gemm(const M& a, const M& b, M& c)
{
c.noalias() += a * b;
}
@@ -80,8 +94,10 @@ int main(int argc, char ** argv)
{
std::ptrdiff_t l1 = ei_queryL1CacheSize();
std::ptrdiff_t l2 = ei_queryTopLevelCacheSize();
- std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
- std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
+ std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
+ std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
+ typedef ei_product_blocking_traits<Scalar> Blocking;
+ std::cout << "Register blocking = " << Blocking::mr << " x " << Blocking::nr << "\n";
int rep = 1; // number of repetitions per try
int tries = 2; // number of tries, we keep the best