From a7b9250ad04fe02f9c51085164478bc1687577f3 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 1 Mar 2010 19:06:07 +0100 Subject: blas interface: fix compilation, fix GEMM, SYMM, TRMM, and TRSM, i,e., they all pass the blas test suite. More to come --- blas/level3_impl.h | 204 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 125 insertions(+), 79 deletions(-) (limited to 'blas/level3_impl.h') diff --git a/blas/level3_impl.h b/blas/level3_impl.h index d44de1b5d..76497ec26 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -26,8 +26,9 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; typedef void (*functype)(int, int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[12]; + static functype func[12]; static bool init = false; if(!init) @@ -52,21 +53,29 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); - if(beta!=Scalar(1)) - matrix(c, *m, *n, *ldc) *= beta; - int code = OP(*opa) | (OP(*opb) << 2); - if(code>=12 || func[code]==0) + if(code>=12 || func[code]==0 || (*m<0) || (*n<0) || (*k<0)) + { + int info = 1; + xerbla_("GEMM", &info, 4); return 0; + } + + if(beta!=Scalar(1)) + if(beta==Scalar(0)) + matrix(c, *m, *n, *ldc).setZero(); + else + matrix(c, *m, *n, *ldc) *= beta; func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha); - return 1; + return 0; } int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) { +// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n"; typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int); - functype func[32]; + static functype func[32]; static bool init = false; if(!init) @@ -74,38 +83,38 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, for(int k=0; k<32; ++k) func[k] = 0; - func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_triangular_solve_matrix::run); init = true; } @@ -114,14 +123,23 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar* b = reinterpret_cast(pb); Scalar alpha = *reinterpret_cast(palpha); - // TODO handle alpha - int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0) + if(code>=32 || func[code]==0 || *m<0 || *n <0) + { + int info=1; + xerbla_("TRSM",&info,4); return 0; + } - func[code](*m, *n, a, *lda, b, *ldb); - return 1; + if(SIDE(*side)==LEFT) + func[code](*m, *n, a, *lda, b, *ldb); + else + func[code](*n, *m, a, *lda, b, *ldb); + + if(alpha!=Scalar(1)) + matrix(b,*m,*n,*ldb) *= alpha; + + return 0; } @@ -129,46 +147,46 @@ int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, // b = alpha*b*op(a) for side = 'R'or'r' int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb) { +// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n"; typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[32]; - + static functype func[32]; static bool init = false; if(!init) { for(int k=0; k<32; ++k) func[k] = 0; - func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); - func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); + func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (ei_product_triangular_matrix_matrix::run); init = true; } @@ -178,10 +196,21 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, Scalar alpha = *reinterpret_cast(palpha); int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); - if(code>=32 || func[code]==0) + if(code>=32 || func[code]==0 || *m<0 || *n <0) + { + int info=1; + xerbla_("TRMM",&info,4); return 0; + } + + // FIXME find a way to avoid this copy + Matrix tmp = matrix(b,*m,*n,*ldb); + matrix(b,*m,*n,*ldb).setZero(); - func[code](*m, *n, a, *lda, b, *ldb, b, *ldb, alpha); + if(SIDE(*side)==LEFT) + func[code](*m, *n, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha); + else + func[code](*n, *m, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha); return 1; } @@ -189,14 +218,26 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, // c = alpha*b*a + beta*c for side = 'R'or'r int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << " " +// << pa << " " << pb << " " << pc << "\n"; Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); Scalar* c = reinterpret_cast(pc); Scalar alpha = *reinterpret_cast(palpha); Scalar beta = *reinterpret_cast(pbeta); + if(*m<0 || *n<0) + { + int info=1; + xerbla_("SYMM",&info,4); + return 0; + } + if(beta!=Scalar(1)) - matrix(c, *m, *n, *ldc) *= beta; + if(beta==Scalar(0)) + matrix(c, *m, *n, *ldc).setZero(); + else + matrix(c, *m, *n, *ldc) *= beta; if(SIDE(*side)==LEFT) if(UPLO(*uplo)==UP) @@ -215,15 +256,16 @@ int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *pa else return 0; - return 1; + return 0; } // c = alpha*a*a' + beta*c for op = 'N'or'n' // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c' int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) { +// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << "\n"; typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[8]; + static functype func[8]; static bool init = false; if(!init) @@ -231,13 +273,13 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp for(int k=0; k<8; ++k) func[k] = 0; - func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); - func[TR | (UP << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); + func[TR | (UP << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); - func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); - func[TR | (LO << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); + func[TR | (LO << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); init = true; } @@ -248,8 +290,12 @@ int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palp Scalar beta = *reinterpret_cast(pbeta); int code = OP(*op) | (UPLO(*uplo) << 2); - if(code>=8 || func[code]==0) + if(code>=8 || func[code]==0 || *n<0 || *k<0) + { + int info=1; + xerbla_("SYRK",&info,4); return 0; + } if(beta!=Scalar(1)) matrix(c, *n, *n, *ldc) *= beta; @@ -314,7 +360,7 @@ int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *pa int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc) { typedef void (*functype)(int, int, const Scalar *, int, Scalar *, int, Scalar); - functype func[8]; + static functype func[8]; static bool init = false; if(!init) @@ -322,11 +368,11 @@ int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palp for(int k=0; k<8; ++k) func[k] = 0; - func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (UP << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (UP << 2)] = (ei_selfadjoint_product::run); - func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); - func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); + func[NOTR | (LO << 2)] = (ei_selfadjoint_product::run); + func[ADJ | (LO << 2)] = (ei_selfadjoint_product::run); init = true; } -- cgit v1.2.3