aboutsummaryrefslogtreecommitdiffhomepage
path: root/bench
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-02-23 13:06:49 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-02-23 13:06:49 +0100
commiteb905500b6c654860aa9f9d9c77c7c2614e0ad10 (patch)
tree73d13d1389ffb7594777e26a52823f6c45a48eec /bench
parentd579d4cc37693823d03fbfedd2e48c40dcaf8938 (diff)
significant speedup in the matrix-matrix products
Diffstat (limited to 'bench')
-rw-r--r--bench/bench_gemm.cpp4
-rw-r--r--bench/bench_gemm_blas.cpp25
2 files changed, 22 insertions, 7 deletions
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp
index e99fc2970..ccc155dc5 100644
--- a/bench/bench_gemm.cpp
+++ b/bench/bench_gemm.cpp
@@ -22,8 +22,8 @@ void gemm(const M& a, const M& b, M& c)
int main(int argc, char ** argv)
{
- int rep = 2;
- int s = 1024;
+ int rep = 1;
+ int s = 2048;
int m = s;
int n = s;
int p = s;
diff --git a/bench/bench_gemm_blas.cpp b/bench/bench_gemm_blas.cpp
index a9dfaa66f..babf1ec2c 100644
--- a/bench/bench_gemm_blas.cpp
+++ b/bench/bench_gemm_blas.cpp
@@ -31,7 +31,6 @@ static int intone = 1;
void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c)
{
-// cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, c.rows(), c.cols(), a.cols(), 1, a.data(), a.rows(), b.data(), b.rows(), 1, c.data(), c.rows());
int M = c.rows();
int N = c.cols();
int K = a.cols();
@@ -39,17 +38,33 @@ void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c)
int lda = a.rows();
int ldb = b.rows();
int ldc = c.rows();
-
+
sgemm_(&notrans,&notrans,&M,&N,&K,&fone,
const_cast<float*>(a.data()),&lda,
- const_cast<float*>(b.data()),&ldb,&fzero,
+ const_cast<float*>(b.data()),&ldb,&fone,
+ c.data(),&ldc);
+}
+
+void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
+{
+ int M = c.rows();
+ int N = c.cols();
+ int K = a.cols();
+
+ int lda = a.rows();
+ int ldb = b.rows();
+ int ldc = c.rows();
+
+ dgemm_(&notrans,&notrans,&M,&N,&K,&done,
+ const_cast<double*>(a.data()),&lda,
+ const_cast<double*>(b.data()),&ldb,&done,
c.data(),&ldc);
}
int main(int argc, char **argv)
{
- int rep = 2;
- int s = 1024;
+ int rep = 1;
+ int s = 2048;
int m = s;
int n = s;
int p = s;