diff options
author | Gael Guennebaud <g.gael@free.fr> | 2010-03-02 12:44:40 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2010-03-02 12:44:40 +0100 |
commit | a2d7c239f54190ddb40febb6b4b65d74c261f008 (patch) | |
tree | 121c9a6a88ce707419a5af481c3ebd0b1f113b72 | |
parent | 7fd6458fec694f213323d6dd0718d315513adbb5 (diff) |
blas: fix HEMM and HERK
-rw-r--r-- | blas/complex_double.cpp | 2 | ||||
-rw-r--r-- | blas/complex_single.cpp | 2 | ||||
-rw-r--r-- | blas/level3_impl.h | 90 |
3 files changed, 55 insertions, 39 deletions
diff --git a/blas/complex_double.cpp b/blas/complex_double.cpp index f51ccb25b..be2104a56 100644 --- a/blas/complex_double.cpp +++ b/blas/complex_double.cpp @@ -23,7 +23,7 @@ // Eigen. If not, see <http://www.gnu.org/licenses/>. #define SCALAR std::complex<double> -#define SCALAR_SUFFIX c +#define SCALAR_SUFFIX z #define ISCOMPLEX 1 #include "level1_impl.h" diff --git a/blas/complex_single.cpp b/blas/complex_single.cpp index b6617e7b9..2b13bc7ce 100644 --- a/blas/complex_single.cpp +++ b/blas/complex_single.cpp @@ -23,7 +23,7 @@ // Eigen. If not, see <http://www.gnu.org/licenses/>. #define SCALAR std::complex<float> -#define SCALAR_SUFFIX z +#define SCALAR_SUFFIX c #define ISCOMPLEX 1 #include "level1_impl.h" diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 32b49b118..c9023ab37 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -218,8 +218,7 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, // c = alpha*b*a + beta*c for side = 'R'or'r int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { -// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << " " -// << pa << " " << pb << " " << pc << "\n"; +// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; Scalar* a = reinterpret_cast<Scalar*>(pa); Scalar* b = reinterpret_cast<Scalar*>(pb); Scalar* c = reinterpret_cast<Scalar*>(pc); @@ -234,25 +233,17 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa } if(beta!=Scalar(1)) - if(beta==Scalar(0)) - matrix(c, *m, *n, *ldc).setZero(); - else - matrix(c, *m, *n, *ldc) *= beta; + if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); + else matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix<Scalar, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix<Scalar, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix<Scalar, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix<Scalar, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else return 0; else if(SIDE(*side)==RIGHT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else return 0; else return 0; @@ -334,27 +325,30 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); +// std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; + + if(*m<0 || *n<0) + { + return 0; + } + if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix<Scalar, RowMajor,true,Conj, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix<Scalar, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix<Scalar, RowMajor,true,Conj, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix<Scalar, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha); + else return 0; else if(SIDE(*side)==RIGHT) - if(UPLO(*uplo)==UP) - ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, RowMajor,true,Conj, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else if(UPLO(*uplo)==LO) - ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); - else - return 0; + if(UPLO(*uplo)==UP) ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, RowMajor,true,Conj, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else if(UPLO(*uplo)==LO) ei_product_selfadjoint_matrix<Scalar, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha); + else return 0; else + { return 0; + } - return 1; + return 0; } // c = alpha*a*conj(a') + beta*c for op = 'N'or'n' @@ -381,18 +375,35 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp Scalar* a = reinterpret_cast<Scalar*>(pa); Scalar* c = reinterpret_cast<Scalar*>(pc); - Scalar alpha = *reinterpret_cast<Scalar*>(palpha); - Scalar beta = *reinterpret_cast<Scalar*>(pbeta); + RealScalar alpha = *palpha; + RealScalar beta = *pbeta; + +// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; + + if(*n<0 || *k<0) + { + return 0; + } int code = OP(*op) | (UPLO(*uplo) << 2); if(code>=8 || func[code]==0) return 0; - if(beta!=Scalar(1)) - matrix(c, *n, *n, *ldc) *= beta; + if(beta!=RealScalar(1)) + { + if(UPLO(*uplo)==UP) matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; + else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; - func[code](*n, *k, a, *lda, c, *ldc, alpha); - return 1; + matrix(c, *n, *n, *ldc).diagonal().real() *= beta; + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } + + if(*k>0 && alpha!=RealScalar(0)) + { + func[code](*n, *k, a, *lda, c, *ldc, alpha); + matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); + } + return 0; } // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' @@ -405,6 +416,11 @@ int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *pal Scalar alpha = *reinterpret_cast<Scalar*>(palpha); Scalar beta = *reinterpret_cast<Scalar*>(pbeta); + if(*n<0 || *k<0) + { + return 0; + } + // TODO return 0; |