aboutsummaryrefslogtreecommitdiffhomepage
path: root/bench/bench_gemm_blas.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'bench/bench_gemm_blas.cpp')
-rw-r--r--bench/bench_gemm_blas.cpp25
1 files changed, 20 insertions, 5 deletions
diff --git a/bench/bench_gemm_blas.cpp b/bench/bench_gemm_blas.cpp
index a9dfaa66f..babf1ec2c 100644
--- a/bench/bench_gemm_blas.cpp
+++ b/bench/bench_gemm_blas.cpp
@@ -31,7 +31,6 @@ static int intone = 1;
void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c)
{
-// cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, c.rows(), c.cols(), a.cols(), 1, a.data(), a.rows(), b.data(), b.rows(), 1, c.data(), c.rows());
int M = c.rows();
int N = c.cols();
int K = a.cols();
@@ -39,17 +38,33 @@ void blas_gemm(const MatrixXf& a, const MatrixXf& b, MatrixXf& c)
int lda = a.rows();
int ldb = b.rows();
int ldc = c.rows();
-
+
sgemm_(&notrans,&notrans,&M,&N,&K,&fone,
const_cast<float*>(a.data()),&lda,
- const_cast<float*>(b.data()),&ldb,&fzero,
+ const_cast<float*>(b.data()),&ldb,&fone,
+ c.data(),&ldc);
+}
+
+void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
+{
+ int M = c.rows();
+ int N = c.cols();
+ int K = a.cols();
+
+ int lda = a.rows();
+ int ldb = b.rows();
+ int ldc = c.rows();
+
+ dgemm_(&notrans,&notrans,&M,&N,&K,&done,
+ const_cast<double*>(a.data()),&lda,
+ const_cast<double*>(b.data()),&ldb,&done,
c.data(),&ldc);
}
int main(int argc, char **argv)
{
- int rep = 2;
- int s = 1024;
+ int rep = 1;
+ int s = 2048;
int m = s;
int n = s;
int p = s;