diff options
author | Gael Guennebaud <g.gael@free.fr> | 2019-02-18 11:47:54 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2019-02-18 11:47:54 +0100 |
commit | 512b74aaa19fa12a05774dd30205d2c97e8bdef9 (patch) | |
tree | efadb2022fb2291c4b733a7c7f4670dce6b01ba3 /Eigen/src/Core/util/BlasUtil.h | |
parent | ec032ac03b90dc6c58680a4dc858133e9a72fd1f (diff) |
GEMM: catch all scalar-multiple variants when falling-back to a coeff-based product.
Before only s*A*B was caught which was both inconsistent with GEMM, sub-optimal,
and could even lead to compilation-errors (https://stackoverflow.com/questions/54738495).
Diffstat (limited to 'Eigen/src/Core/util/BlasUtil.h')
-rwxr-xr-x | Eigen/src/Core/util/BlasUtil.h | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index a32630ed7..bc0a01540 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -274,7 +274,8 @@ template<typename XprType> struct blas_traits HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit) && ( bool(XprType::IsVectorAtCompileTime) || int(inner_stride_at_compile_time<XprType>::ret) == 1) - ) ? 1 : 0 + ) ? 1 : 0, + HasScalarFactor = false }; typedef typename conditional<bool(HasUsableDirectAccess), ExtractType, @@ -306,6 +307,9 @@ template<typename Scalar, typename NestedXpr, typename Plain> struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> > : blas_traits<NestedXpr> { + enum { + HasScalarFactor = true + }; typedef blas_traits<NestedXpr> Base; typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType; typedef typename Base::ExtractType ExtractType; @@ -317,6 +321,9 @@ template<typename Scalar, typename NestedXpr, typename Plain> struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > > : blas_traits<NestedXpr> { + enum { + HasScalarFactor = true + }; typedef blas_traits<NestedXpr> Base; typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType; typedef typename Base::ExtractType ExtractType; @@ -335,6 +342,9 @@ template<typename Scalar, typename NestedXpr> struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > : blas_traits<NestedXpr> { + enum { + HasScalarFactor = true + }; typedef blas_traits<NestedXpr> Base; typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType; typedef typename Base::ExtractType ExtractType; @@ -358,7 +368,7 @@ struct blas_traits<Transpose<NestedXpr> > typename ExtractType::PlainObject >::type DirectLinearAccessType; enum { - IsTransposed = Base::IsTransposed ? 0 : 1 + IsTransposed = Base::IsTransposed ? 0 : 1, }; static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); } static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } |