aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/products/TriangularMatrixVector_MKL.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2016-04-11 15:17:14 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2016-04-11 15:17:14 +0200
commit6a9ca88e7e1bb72de621806b51c5a4fd17310943 (patch)
treef31ced3c1f2fba0e7f230b4c2acd71d1d1581c97 /Eigen/src/Core/products/TriangularMatrixVector_MKL.h
parent4e8e5888d7a78d514e54a518f6692f2838314328 (diff)
Relax dependency on MKL for EIGEN_USE_BLAS
Diffstat (limited to 'Eigen/src/Core/products/TriangularMatrixVector_MKL.h')
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector_MKL.h36
1 files changed, 16 insertions, 20 deletions
diff --git a/Eigen/src/Core/products/TriangularMatrixVector_MKL.h b/Eigen/src/Core/products/TriangularMatrixVector_MKL.h
index 3672b1240..3aaea3457 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector_MKL.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector_MKL.h
@@ -107,9 +107,7 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
EIGTYPE const *a; \
- MKLTYPE alpha_, beta_; \
- assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
- assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
+ EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
@@ -123,10 +121,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
- MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
+ MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
- MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
+ MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@@ -144,15 +142,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = size; \
n = cols-size; \
} \
- MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
+ MKLPREFIX##gemv_(&trans, &m, &n, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
} \
} \
};
-EIGEN_MKL_TRMV_CM(double, double, d, d)
-EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
-EIGEN_MKL_TRMV_CM(float, float, f, s)
-EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
+EIGEN_MKL_TRMV_CM(double, double, d, d)
+EIGEN_MKL_TRMV_CM(dcomplex, double, cd, z)
+EIGEN_MKL_TRMV_CM(float, float, f, s)
+EIGEN_MKL_TRMV_CM(scomplex, float, cf, c)
// implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
@@ -191,9 +189,7 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
char trans, uplo, diag; \
MKL_INT m, n, lda, incx, incy; \
EIGTYPE const *a; \
- MKLTYPE alpha_, beta_; \
- assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
- assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
+ EIGTYPE beta(1); \
\
/* Set m, n */ \
n = (MKL_INT)size; \
@@ -207,10 +203,10 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
- MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
+ MKLPREFIX##trmv_(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
- MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
+ MKLPREFIX##axpy_(&n, &numext::real_ref(alpha),(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to MKL ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
@@ -228,15 +224,15 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
m = size; \
n = cols-size; \
} \
- MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
+ MKLPREFIX##gemv_(&trans, &n, &m, &numext::real_ref(alpha), (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &numext::real_ref(beta), (MKLTYPE*)y, &incy); \
} \
} \
};
-EIGEN_MKL_TRMV_RM(double, double, d, d)
-EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
-EIGEN_MKL_TRMV_RM(float, float, f, s)
-EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
+EIGEN_MKL_TRMV_RM(double, double, d, d)
+EIGEN_MKL_TRMV_RM(dcomplex, double, cd, z)
+EIGEN_MKL_TRMV_RM(float, float, f, s)
+EIGEN_MKL_TRMV_RM(scomplex, float, cf, c)
} // end namespase internal