aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2016-02-15 21:43:07 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2016-02-15 21:43:07 +0100
commitf6f057bb7d3fcd24b751cba2e70d416f4a803e1f (patch)
treecf0d1239c45aac2a3f4da0370a3dcc9c95f65243
parent8e1f1ba6a6cf0580da6f8756562f94b6410d5e58 (diff)
bug #1166: fix shortcomming in gemv when the destination is not a vector at compile-time.
-rw-r--r--Eigen/src/Core/GeneralProduct.h11
-rw-r--r--test/product.h16
2 files changed, 23 insertions, 4 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 0769a212e..53f934999 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -213,15 +213,18 @@ template<> struct gemv_dense_selector<OnTheRight,ColMajor,true>
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
* RhsBlasTraits::extractScalarFactor(rhs);
+ // make sure Dest is a compile-time vector type (bug 1166)
+ typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
+
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
// on, the other hand it is good for the cache to pack the vector anyways...
- EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
+ EvalToDestAtCompileTime = (ActualDest::InnerStrideAtCompileTime==1),
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
- MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
+ MightCannotUseDest = (ActualDest::InnerStrideAtCompileTime!=1) || ComplexByReal
};
- gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
+ gemv_static_vector_if<ResScalar,ActualDest::SizeAtCompileTime,ActualDest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
const bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
const bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
@@ -314,7 +317,7 @@ template<> struct gemv_dense_selector<OnTheRight,RowMajor,true>
actualLhs.rows(), actualLhs.cols(),
LhsMapper(actualLhs.data(), actualLhs.outerStride()),
RhsMapper(actualRhsPtr, 1),
- dest.data(), dest.innerStride(),
+ dest.data(), dest.col(0).innerStride(), //NOTE if dest is not a vector at compile-time, then dest.innerStride() might be wrong. (bug 1166)
actualAlpha);
}
};
diff --git a/test/product.h b/test/product.h
index bd92309d2..45bb64958 100644
--- a/test/product.h
+++ b/test/product.h
@@ -144,6 +144,22 @@ template<typename MatrixType> void product(const MatrixType& m)
VERIFY_IS_APPROX(res.col(r).noalias() = square.adjoint() * square.col(r), (square.adjoint() * square.col(r)).eval());
VERIFY_IS_APPROX(res.col(r).noalias() = square * square.col(r), (square * square.col(r)).eval());
+ // vector at runtime (see bug 1166)
+ {
+ RowSquareMatrixType ref(square);
+ ColSquareMatrixType ref2(square2);
+ ref = res = square;
+ VERIFY_IS_APPROX(res.block(0,0,1,rows).noalias() = m1.col(0).transpose() * square.transpose(), (ref.row(0) = m1.col(0).transpose() * square.transpose()));
+ VERIFY_IS_APPROX(res.block(0,0,1,rows).noalias() = m1.block(0,0,rows,1).transpose() * square.transpose(), (ref.row(0) = m1.col(0).transpose() * square.transpose()));
+ VERIFY_IS_APPROX(res.block(0,0,1,rows).noalias() = m1.col(0).transpose() * square, (ref.row(0) = m1.col(0).transpose() * square));
+ VERIFY_IS_APPROX(res.block(0,0,1,rows).noalias() = m1.block(0,0,rows,1).transpose() * square, (ref.row(0) = m1.col(0).transpose() * square));
+ ref2 = res2 = square2;
+ VERIFY_IS_APPROX(res2.block(0,0,1,cols).noalias() = m1.row(0) * square2.transpose(), (ref2.row(0) = m1.row(0) * square2.transpose()));
+ VERIFY_IS_APPROX(res2.block(0,0,1,cols).noalias() = m1.block(0,0,1,cols) * square2.transpose(), (ref2.row(0) = m1.row(0) * square2.transpose()));
+ VERIFY_IS_APPROX(res2.block(0,0,1,cols).noalias() = m1.row(0) * square2, (ref2.row(0) = m1.row(0) * square2));
+ VERIFY_IS_APPROX(res2.block(0,0,1,cols).noalias() = m1.block(0,0,1,cols) * square2, (ref2.row(0) = m1.row(0) * square2));
+ }
+
// inner product
{
Scalar x = square2.row(c) * square2.col(c2);