about | summary | refs | log | tree | commit | diff | homepage
path: root/bench/bench_gemm.cpp
diff options
context:
space:
mode:
author: Gravatar Gael Guennebaud <g.gael@free.fr> 2010-02-26 12:32:00 +0100
committer: Gravatar Gael Guennebaud <g.gael@free.fr> 2010-02-26 12:32:00 +0100
commit: 3ac2b96a2f131e8162d39f0976cfb31b1a853237 (patch)
tree: 977798f989db9d3182f48807b43ba7269029d216 /bench/bench_gemm.cpp
parent: a1e110332829a4bb38ca8e55608a2b048876018e (diff)
implement a smarter parallelization strategy for gemm avoiding multiple
packing of the same data
Diffstat (limited to 'bench/bench_gemm.cpp')
-rw-r--r-- bench/bench_gemm.cpp 79
1 files changed, 73 insertions, 6 deletions
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp
index ccc155dc5..d958cc1bf 100644
--- a/bench/bench_gemm.cpp
+++ b/bench/bench_gemm.cpp
@@ -15,6 +15,52 @@ using namespace Eigen;
typedef SCALAR Scalar;
typedef Matrix<Scalar,Dynamic,Dynamic> M;
+#ifdef HAVE_BLAS
+
+extern "C" {
+ #include <bench/btl/libs/C_BLAS/blas.h>
+
+ void sgemm_kernel(int actual_mc, int cols, int actual_kc, float alpha,
+ float* blockA, float* blockB, float* res, int resStride);
+ void sgemm_oncopy(int actual_kc, int cols, const float* rhs, int rhsStride, float* blockB);
+ void sgemm_itcopy(int actual_kc, int cols, const float* rhs, int rhsStride, float* blockB);
+}
+
+static float fone = 1;
+static float fzero = 0;
+static double done = 1;
+static double szero = 0;
+static char notrans = 'N';
+static char trans = 'T';
+static char nonunit = 'N';
+static char lower = 'L';
+static char right = 'R';
+static int intone = 1;
+
+void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c)
+{
+ int M = c.rows(); int N = c.cols(); int K = a.cols();
+ int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows();
+
+ sgemm_(&notrans,&notrans,&M,&N,&K,&fone,
+ const_cast<float*>(a.data()),&lda,
+ const_cast<float*>(b.data()),&ldb,&fone,
+ c.data(),&ldc);
+}
+
+void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
+{
+ int M = c.rows(); int N = c.cols(); int K = a.cols();
+ int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows();
+
+ dgemm_(&notrans,&notrans,&M,&N,&K,&done,
+ const_cast<double*>(a.data()),&lda,
+ const_cast<double*>(b.data()),&ldb,&done,
+ c.data(),&ldc);
+}
+
+#endif
+
void gemm(const M& a, const M& b, M& c)
{
c.noalias() += a * b;
@@ -22,21 +68,42 @@ void gemm(const M& a, const M& b, M& c)
int main(int argc, char ** argv)
{
- int rep = 1;
+ int rep = 1; // number of repetitions per try
+ int tries = 5; // number of tries, we keep the best
+
int s = 2048;
int m = s;
int n = s;
int p = s;
- M a(m,n); a.setOnes();
- M b(n,p); b.setOnes();
+ M a(m,n); a.setRandom();
+ M b(n,p); b.setRandom();
M c(m,p); c.setOnes();
BenchTimer t;
- BENCH(t, 5, rep, gemm(a,b,c));
+ M r = c;
+
+ // check the parallel product is correct
+ #ifdef HAVE_BLAS
+ blas_gemm(a,b,r);
+ #else
+ int procs = omp_get_max_threads();
+ omp_set_num_threads(1);
+ r.noalias() += a * b;
+ omp_set_num_threads(procs);
+ #endif
+ c.noalias() += a * b;
+ if(!r.isApprox(c)) std::cerr << "Warning, your parallel product is crap!\n\n";
+
+ #ifdef HAVE_BLAS
+ BENCH(t, tries, rep, blas_gemm(a,b,c));
+ std::cerr << "blas cpu " << t.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << t.total(CPU_TIMER) << "s)\n";
+ std::cerr << "blas real " << t.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
+ #endif
- std::cerr << "cpu " << t.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << t.total(CPU_TIMER) << "s)\n";
- std::cerr << "real " << t.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
+ BENCH(t, tries, rep, gemm(a,b,c));
+ std::cerr << "eigen cpu " << t.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << t.total(CPU_TIMER) << "s)\n";
+ std::cerr << "eigen real " << t.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
return 0;
}