aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2017-08-17 12:07:10 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2017-08-17 12:07:10 +0200
commitb95f92843c58a914c46ab091009146288b8b775c (patch)
tree2614c1eb9d2f8a21d0bcaf03635ffcb605a3f65f /Eigen
parent89c01a494aff2bd03b48a9858eed95a4a7ce9556 (diff)
Fix support for MKL's BLAS when using MKL_DIRECT_CALL.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h8
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h19
-rw-r--r--Eigen/src/Core/products/GeneralMatrixVector_BLAS.h19
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h48
-rw-r--r--Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h9
-rw-r--r--Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h39
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector_BLAS.h46
-rw-r--r--Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h40
-rwxr-xr-xEigen/src/Core/util/MKL_support.h8
9 files changed, 157 insertions, 79 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
index 41e18ff07..9176a1382 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h
@@ -88,7 +88,7 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
BlasIndex lda=convert_index<BlasIndex>(lhsStride), ldc=convert_index<BlasIndex>(resStride), n=convert_index<BlasIndex>(size), k=convert_index<BlasIndex>(depth); \
char uplo=((IsLower) ? 'L' : 'U'), trans=((AStorageOrder==RowMajor) ? 'T':'N'); \
EIGTYPE beta(1); \
- BLASFUNC(&uplo, &trans, &n, &k, &numext::real_ref(alpha), lhs, &lda, &numext::real_ref(beta), res, &ldc); \
+ BLASFUNC(&uplo, &trans, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), lhs, &lda, (const BLASTYPE*)&numext::real_ref(beta), res, &ldc); \
} \
};
@@ -125,9 +125,13 @@ struct general_matrix_matrix_rankupdate<Index,EIGTYPE,AStorageOrder,ConjugateA,C
} \
};
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_RANKUPDATE_R(double, double, dsyrk)
+EIGEN_BLAS_RANKUPDATE_R(float, float, ssyrk)
+#else
EIGEN_BLAS_RANKUPDATE_R(double, double, dsyrk_)
EIGEN_BLAS_RANKUPDATE_R(float, float, ssyrk_)
+#endif
// TODO hanlde complex cases
// EIGEN_BLAS_RANKUPDATE_C(dcomplex, double, double, zherk_)
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
index 7a3bdbf20..b0f6b0d5b 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h
@@ -46,7 +46,7 @@ namespace internal {
// gemm specialization
-#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASPREFIX) \
+#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASFUNC) \
template< \
typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
@@ -100,13 +100,20 @@ static void run(Index rows, Index cols, Index depth, \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \
\
- BLASPREFIX##gemm_(&transa, &transb, &m, &n, &k, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
}};
-GEMM_SPECIALIZATION(double, d, double, d)
-GEMM_SPECIALIZATION(float, f, float, s)
-GEMM_SPECIALIZATION(dcomplex, cd, double, z)
-GEMM_SPECIALIZATION(scomplex, cf, float, c)
+#ifdef EIGEN_USE_MKL
+GEMM_SPECIALIZATION(double, d, double, dgemm)
+GEMM_SPECIALIZATION(float, f, float, sgemm)
+GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, zgemm)
+GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, cgemm)
+#else
+GEMM_SPECIALIZATION(double, d, double, dgemm_)
+GEMM_SPECIALIZATION(float, f, float, sgemm_)
+GEMM_SPECIALIZATION(dcomplex, cd, double, zgemm_)
+GEMM_SPECIALIZATION(scomplex, cf, float, cgemm_)
+#endif
} // end namespase internal
diff --git a/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h b/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h
index e3a5d5892..6e36c2b3c 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h
@@ -85,7 +85,7 @@ EIGEN_BLAS_GEMV_SPECIALIZE(float)
EIGEN_BLAS_GEMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_GEMV_SPECIALIZE(scomplex)
-#define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASPREFIX) \
+#define EIGEN_BLAS_GEMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
template<typename Index, int LhsStorageOrder, bool ConjugateLhs, bool ConjugateRhs> \
struct general_matrix_vector_product_gemv<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,ConjugateRhs> \
{ \
@@ -113,14 +113,21 @@ static void run( \
x_ptr=x_tmp.data(); \
incx=1; \
} else x_ptr=rhs; \
- BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
+ BLASFUNC(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\
};
-EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, d)
-EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, s)
-EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, z)
-EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, c)
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, dgemv)
+EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, sgemv)
+EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, MKL_Complex16, zgemv)
+EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, MKL_Complex8 , cgemv)
+#else
+EIGEN_BLAS_GEMV_SPECIALIZATION(double, double, dgemv_)
+EIGEN_BLAS_GEMV_SPECIALIZATION(float, float, sgemv_)
+EIGEN_BLAS_GEMV_SPECIALIZATION(dcomplex, double, zgemv_)
+EIGEN_BLAS_GEMV_SPECIALIZATION(scomplex, float, cgemv_)
+#endif
} // end namespase internal
diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h
index a45238d69..9a5318507 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h
@@ -40,7 +40,7 @@ namespace internal {
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
-#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -81,13 +81,13 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _rhs; \
\
- BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
-#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -144,20 +144,26 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} \
\
- BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
-EIGEN_BLAS_SYMM_L(double, double, d, d)
-EIGEN_BLAS_SYMM_L(float, float, f, s)
-EIGEN_BLAS_HEMM_L(dcomplex, double, cd, z)
-EIGEN_BLAS_HEMM_L(scomplex, float, cf, c)
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_SYMM_L(double, double, d, dsymm)
+EIGEN_BLAS_SYMM_L(float, float, f, ssymm)
+EIGEN_BLAS_HEMM_L(dcomplex, MKL_Complex16, cd, zhemm)
+EIGEN_BLAS_HEMM_L(scomplex, MKL_Complex8, cf, chemm)
+#else
+EIGEN_BLAS_SYMM_L(double, double, d, dsymm_)
+EIGEN_BLAS_SYMM_L(float, float, f, ssymm_)
+EIGEN_BLAS_HEMM_L(dcomplex, double, cd, zhemm_)
+EIGEN_BLAS_HEMM_L(scomplex, float, cf, chemm_)
+#endif
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
-#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -197,13 +203,13 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} else b = _lhs; \
\
- BLASPREFIX##symm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
\
} \
};
-#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -259,15 +265,21 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
} \
\
- BLASPREFIX##hemm_(&side, &uplo, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, &numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
+ BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
} \
};
-EIGEN_BLAS_SYMM_R(double, double, d, d)
-EIGEN_BLAS_SYMM_R(float, float, f, s)
-EIGEN_BLAS_HEMM_R(dcomplex, double, cd, z)
-EIGEN_BLAS_HEMM_R(scomplex, float, cf, c)
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_SYMM_R(double, double, d, dsymm)
+EIGEN_BLAS_SYMM_R(float, float, f, ssymm)
+EIGEN_BLAS_HEMM_R(dcomplex, MKL_Complex16, cd, zhemm)
+EIGEN_BLAS_HEMM_R(scomplex, MKL_Complex8, cf, chemm)
+#else
+EIGEN_BLAS_SYMM_R(double, double, d, dsymm_)
+EIGEN_BLAS_SYMM_R(float, float, f, ssymm_)
+EIGEN_BLAS_HEMM_R(dcomplex, double, cd, zhemm_)
+EIGEN_BLAS_HEMM_R(scomplex, float, cf, chemm_)
+#endif
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h b/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h
index 38f23accf..1238345e3 100644
--- a/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h
+++ b/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h
@@ -95,14 +95,21 @@ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \
} else x_ptr=_rhs; \
- BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
+ BLASFUNC(&uplo, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\
};
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv)
+EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv)
+EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv)
+EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv)
+#else
EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
+#endif
} // end namespace internal
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
index aecded6bb..a25197ab0 100644
--- a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h
@@ -75,7 +75,7 @@ EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// implements col-major += alpha * op(triangular) * op(general)
-#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -172,7 +172,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
} \
/*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
- BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
+ BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@@ -180,13 +180,20 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
} \
};
-EIGEN_BLAS_TRMM_L(double, double, d, d)
-EIGEN_BLAS_TRMM_L(dcomplex, double, cd, z)
-EIGEN_BLAS_TRMM_L(float, float, f, s)
-EIGEN_BLAS_TRMM_L(scomplex, float, cf, c)
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
+EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
+EIGEN_BLAS_TRMM_L(float, float, f, strmm)
+EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
+#else
+EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
+EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
+EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
+EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
+#endif
// implements col-major += alpha * op(general) * op(triangular)
-#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
@@ -282,7 +289,7 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
} \
/*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
- BLASPREFIX##trmm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
+ BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
@@ -290,11 +297,17 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
} \
};
-EIGEN_BLAS_TRMM_R(double, double, d, d)
-EIGEN_BLAS_TRMM_R(dcomplex, double, cd, z)
-EIGEN_BLAS_TRMM_R(float, float, f, s)
-EIGEN_BLAS_TRMM_R(scomplex, float, cf, c)
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
+EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
+EIGEN_BLAS_TRMM_R(float, float, f, strmm)
+EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
+#else
+EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
+EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
+EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
+EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
+#endif
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h b/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h
index 07bf26ce5..3d47a2b94 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h
@@ -71,7 +71,7 @@ EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// implements col-major: res += alpha * op(triangular) * vector
-#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
enum { \
@@ -121,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
- BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
+ BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
- BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
+ BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@@ -142,18 +142,25 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
- BLASPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
+ BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
-EIGEN_BLAS_TRMV_CM(double, double, d, d)
-EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z)
-EIGEN_BLAS_TRMV_CM(float, float, f, s)
-EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c)
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRMV_CM(double, double, d, d,)
+EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z,)
+EIGEN_BLAS_TRMV_CM(float, float, f, s,)
+EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c,)
+#else
+EIGEN_BLAS_TRMV_CM(double, double, d, d, _)
+EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z, _)
+EIGEN_BLAS_TRMV_CM(float, float, f, s, _)
+EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c, _)
+#endif
// implements row-major: res += alpha * op(triangular) * vector
-#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX) \
+#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
enum { \
@@ -203,10 +210,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
- BLASPREFIX##trmv_(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
+ BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
- BLASPREFIX##axpy_(&n, &numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
+ BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@@ -224,15 +231,22 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
- BLASPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, &numext::real_ref(beta), (BLASTYPE*)y, &incy); \
+ BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
-EIGEN_BLAS_TRMV_RM(double, double, d, d)
-EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z)
-EIGEN_BLAS_TRMV_RM(float, float, f, s)
-EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c)
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRMV_RM(double, double, d, d,)
+EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z,)
+EIGEN_BLAS_TRMV_RM(float, float, f, s,)
+EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c,)
+#else
+EIGEN_BLAS_TRMV_RM(double, double, d, d,_)
+EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z,_)
+EIGEN_BLAS_TRMV_RM(float, float, f, s,_)
+EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c,_)
+#endif
} // end namespase internal
diff --git a/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h
index 88c0fb794..f0775116a 100644
--- a/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h
+++ b/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h
@@ -38,7 +38,7 @@ namespace Eigen {
namespace internal {
// implements LeftSide op(triangular)^-1 * general
-#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASPREFIX) \
+#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \
@@ -80,18 +80,24 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
- BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
+ BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \
};
-EIGEN_BLAS_TRSM_L(double, double, d)
-EIGEN_BLAS_TRSM_L(dcomplex, double, z)
-EIGEN_BLAS_TRSM_L(float, float, s)
-EIGEN_BLAS_TRSM_L(scomplex, float, c)
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRSM_L(double, double, dtrsm)
+EIGEN_BLAS_TRSM_L(dcomplex, MKL_Complex16, ztrsm)
+EIGEN_BLAS_TRSM_L(float, float, strsm)
+EIGEN_BLAS_TRSM_L(scomplex, MKL_Complex8, ctrsm)
+#else
+EIGEN_BLAS_TRSM_L(double, double, dtrsm_)
+EIGEN_BLAS_TRSM_L(dcomplex, double, ztrsm_)
+EIGEN_BLAS_TRSM_L(float, float, strsm_)
+EIGEN_BLAS_TRSM_L(scomplex, float, ctrsm_)
+#endif
// implements RightSide general * op(triangular)^-1
-#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASPREFIX) \
+#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
{ \
@@ -133,16 +139,22 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm*/ \
- BLASPREFIX##trsm_(&side, &uplo, &transa, &diag, &m, &n, &numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
+ BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
/*std::cout << "TRMS_L specialization!\n";*/ \
} \
};
-EIGEN_BLAS_TRSM_R(double, double, d)
-EIGEN_BLAS_TRSM_R(dcomplex, double, z)
-EIGEN_BLAS_TRSM_R(float, float, s)
-EIGEN_BLAS_TRSM_R(scomplex, float, c)
-
+#ifdef EIGEN_USE_MKL
+EIGEN_BLAS_TRSM_R(double, double, dtrsm)
+EIGEN_BLAS_TRSM_R(dcomplex, MKL_Complex16, ztrsm)
+EIGEN_BLAS_TRSM_R(float, float, strsm)
+EIGEN_BLAS_TRSM_R(scomplex, MKL_Complex8, ctrsm)
+#else
+EIGEN_BLAS_TRSM_R(double, double, dtrsm_)
+EIGEN_BLAS_TRSM_R(dcomplex, double, ztrsm_)
+EIGEN_BLAS_TRSM_R(float, float, strsm_)
+EIGEN_BLAS_TRSM_R(scomplex, float, ctrsm_)
+#endif
} // end namespace internal
diff --git a/Eigen/src/Core/util/MKL_support.h b/Eigen/src/Core/util/MKL_support.h
index 26b59669e..e656799bf 100755
--- a/Eigen/src/Core/util/MKL_support.h
+++ b/Eigen/src/Core/util/MKL_support.h
@@ -53,6 +53,7 @@
#define EIGEN_USE_MKL
#endif
+
#if defined EIGEN_USE_MKL
# include <mkl.h>
/*Check IMKL version for compatibility: < 10.3 is not usable with Eigen*/
@@ -108,6 +109,10 @@
#endif
#endif
+#if defined(EIGEN_USE_BLAS) && !defined(EIGEN_USE_MKL)
+#include "../../misc/blas.h"
+#endif
+
namespace Eigen {
typedef std::complex<double> dcomplex;
@@ -121,8 +126,5 @@ typedef int BlasIndex;
} // end namespace Eigen
-#if defined(EIGEN_USE_BLAS)
-#include "../../misc/blas.h"
-#endif
#endif // EIGEN_MKL_SUPPORT_H