aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/util/BlasUtil.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2019-02-18 11:47:54 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2019-02-18 11:47:54 +0100
commit512b74aaa19fa12a05774dd30205d2c97e8bdef9 (patch)
treeefadb2022fb2291c4b733a7c7f4670dce6b01ba3 /Eigen/src/Core/util/BlasUtil.h
parentec032ac03b90dc6c58680a4dc858133e9a72fd1f (diff)
GEMM: catch all scalar-multiple variants when falling-back to a coeff-based product.
Before only s*A*B was caught which was both inconsistent with GEMM, sub-optimal, and could even lead to compilation-errors (https://stackoverflow.com/questions/54738495).
Diffstat (limited to 'Eigen/src/Core/util/BlasUtil.h')
-rwxr-xr-xEigen/src/Core/util/BlasUtil.h14
1 files changed, 12 insertions, 2 deletions
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index a32630ed7..bc0a01540 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -274,7 +274,8 @@ template<typename XprType> struct blas_traits
HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
&& ( bool(XprType::IsVectorAtCompileTime)
|| int(inner_stride_at_compile_time<XprType>::ret) == 1)
- ) ? 1 : 0
+ ) ? 1 : 0,
+ HasScalarFactor = false
};
typedef typename conditional<bool(HasUsableDirectAccess),
ExtractType,
@@ -306,6 +307,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType;
@@ -317,6 +321,9 @@ template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
typedef typename Base::ExtractType ExtractType;
@@ -335,6 +342,9 @@ template<typename Scalar, typename NestedXpr>
struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
: blas_traits<NestedXpr>
{
+ enum {
+ HasScalarFactor = true
+ };
typedef blas_traits<NestedXpr> Base;
typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType;
@@ -358,7 +368,7 @@ struct blas_traits<Transpose<NestedXpr> >
typename ExtractType::PlainObject
>::type DirectLinearAccessType;
enum {
- IsTransposed = Base::IsTransposed ? 0 : 1
+ IsTransposed = Base::IsTransposed ? 0 : 1,
};
static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }