diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-02-23 13:06:49 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-02-23 13:06:49 +0100 |
commit | eb905500b6c654860aa9f9d9c77c7c2614e0ad10 (patch) | |
tree | 73d13d1389ffb7594777e26a52823f6c45a48eec /bench | |
parent | d579d4cc37693823d03fbfedd2e48c40dcaf8938 (diff) |
significant speedup in the matrix-matrix products
Diffstat (limited to 'bench')
-rw-r--r-- | bench/bench_gemm.cpp | 4 | ||||
-rw-r--r-- | bench/bench_gemm_blas.cpp | 25 |
2 files changed, 22 insertions, 7 deletions
diff --git a/bench/bench_gemm.cpp b/bench/bench_gemm.cpp index e99fc2970..ccc155dc5 100644 --- a/bench/bench_gemm.cpp +++ b/bench/bench_gemm.cpp @@ -22,8 +22,8 @@ void gemm(const M& a, const M& b, M& c) int main(int argc, char ** argv) { - int rep = 2; - int s = 1024; + int rep = 1; + int s = 2048; int m = s; int n = s; int p = s; diff --git a/bench/bench_gemm_blas.cpp b/bench/bench_gemm_blas.cpp index a9dfaa66f..babf1ec2c 100644 --- a/bench/bench_gemm_blas.cpp +++ b/bench/bench_gemm_blas.cpp @@ -31,7 +31,6 @@ static int intone = 1; void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c) { -// cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, c.rows(), c.cols(), a.cols(), 1, a.data(), a.rows(), b.data(), b.rows(), 1, c.data(), c.rows()); int M = c.rows(); int N = c.cols(); int K = a.cols(); @@ -39,17 +38,33 @@ void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c) int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows(); - + sgemm_(¬rans,¬rans,&M,&N,&K,&fone, const_cast<float*>(a.data()),&lda, - const_cast<float*>(b.data()),&ldb,&fzero, + const_cast<float*>(b.data()),&ldb,&fone, + c.data(),&ldc); +} + +void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c) +{ + int M = c.rows(); + int N = c.cols(); + int K = a.cols(); + + int lda = a.rows(); + int ldb = b.rows(); + int ldc = c.rows(); + + dgemm_(¬rans,¬rans,&M,&N,&K,&done, + const_cast<double*>(a.data()),&lda, + const_cast<double*>(b.data()),&ldb,&done, c.data(),&ldc); } int main(int argc, char **argv) { - int rep = 2; - int s = 1024; + int rep = 1; + int s = 2048; int m = s; int n = s; int p = s; |