aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/ProductEvaluators.h7
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h10
2 files changed, 14 insertions, 3 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 792b1811c..6b32c508b 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -256,7 +256,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
{
dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum();
}
-
+
template<typename Dst>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
{ dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); }
@@ -375,6 +375,11 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
template<typename Dest>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
{
+ // Fallback to inner product if both the lhs and rhs is a runtime vector.
+ if (lhs.rows() == 1 && rhs.cols() == 1) {
+ dst.coeffRef(0,0) += alpha * (lhs.row(0).transpose().cwiseProduct(rhs.col(0)).sum());
+ return;
+ }
LhsNested actual_lhs(lhs);
RhsNested actual_rhs(rhs);
internal::gemv_dense_selector<Side,
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 508c05c97..8cdf14702 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -471,15 +471,21 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0)
return;
- // Fallback to GEMV if either the lhs or rhs is a runtime vector
- if (dst.cols() == 1)
+ if (dst.cols() == 1 && dst.rows() == 1) {
+ // Fallback to inner product if both the lhs and rhs is a runtime vector.
+ dst.coeffRef(0,0) += alpha * (a_lhs.row(0).transpose().cwiseProduct(a_rhs.col(0)).sum());
+ return;
+ }
+ else if (dst.cols() == 1)
{
+ // Fallback to GEMV if either the lhs or rhs is a runtime vector
typename Dest::ColXpr dst_vec(dst.col(0));
return internal::generic_product_impl<Lhs,typename Rhs::ConstColXpr,DenseShape,DenseShape,GemvProduct>
::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha);
}
else if (dst.rows() == 1)
{
+ // Fallback to GEMV if either the lhs or rhs is a runtime vector
typename Dest::RowXpr dst_vec(dst.row(0));
return internal::generic_product_impl<typename Lhs::ConstRowXpr,Rhs,DenseShape,DenseShape,GemvProduct>
::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);