aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/util
diff options
context:
space:
mode:
authorGravatar Christoph Hertzberg <chtz@informatik.uni-bremen.de>2021-01-01 20:54:45 +0100
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-01-05 18:15:30 +0000
commit12dda34b15fe2ea4d994fafdba8c4666963e1a55 (patch)
treefc1d5f1c879b229b9421af3d0a8fe76302551eb8 /Eigen/src/Core/util
parent070d303d56d46d2e018a58214da24ca629ea454f (diff)
Eliminate boolean product warnings by factoring out a
`combine_scalar_factors` helper function.
Diffstat (limited to 'Eigen/src/Core/util')
-rwxr-xr-xEigen/src/Core/util/BlasUtil.h41
1 files changed, 41 insertions, 0 deletions
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index a90e57446..c5161022c 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -618,6 +618,47 @@ template<typename T> const typename T::Scalar* extract_data(const T& m)
return extract_data_selector<T>::run(m);
}
+/**
+ * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products.
+ * There is a specialization for booleans
+ */
+template<typename ResScalar, typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+template<typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl<bool, Lhs, Rhs>
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs);
+}
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs);
+}
+
+
} // end namespace internal
} // end namespace Eigen