From 12dda34b15fe2ea4d994fafdba8c4666963e1a55 Mon Sep 17 00:00:00 2001 From: Christoph Hertzberg Date: Fri, 1 Jan 2021 20:54:45 +0100 Subject: Eliminate boolean product warnings by factoring out a `combine_scalar_factors` helper function. --- Eigen/src/Core/util/BlasUtil.h | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'Eigen/src/Core/util') diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index a90e57446..c5161022c 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -618,6 +618,47 @@ template const typename T::Scalar* extract_data(const T& m) return extract_data_selector::run(m); } +/** + * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products. + * There is a specialization for booleans + */ +template +struct combine_scalar_factors_impl +{ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs) + { + return blas_traits::extractScalarFactor(lhs) * blas_traits::extractScalarFactor(rhs); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) + { + return alpha * blas_traits::extractScalarFactor(lhs) * blas_traits::extractScalarFactor(rhs); + } +}; +template +struct combine_scalar_factors_impl +{ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs) + { + return blas_traits::extractScalarFactor(lhs) && blas_traits::extractScalarFactor(rhs); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs) + { + return alpha && blas_traits::extractScalarFactor(lhs) && blas_traits::extractScalarFactor(rhs); + } +}; + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs) +{ + return combine_scalar_factors_impl::run(alpha, lhs, rhs); +} +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs) +{ + return combine_scalar_factors_impl::run(lhs, rhs); +} + + } // end namespace internal } // end namespace Eigen -- cgit v1.2.3