aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Christoph Hertzberg <chtz@informatik.uni-bremen.de>2021-01-01 20:54:45 +0100
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-01-05 18:15:30 +0000
commit12dda34b15fe2ea4d994fafdba8c4666963e1a55 (patch)
treefc1d5f1c879b229b9421af3d0a8fe76302551eb8
parent070d303d56d46d2e018a58214da24ca629ea454f (diff)
Eliminate boolean product warnings by factoring out a
`combine_scalar_factors` helper function.
-rw-r--r--Eigen/src/Core/GeneralProduct.h6
-rw-r--r--Eigen/src/Core/ProductEvaluators.h4
-rw-r--r--Eigen/src/Core/products/GeneralMatrixMatrix.h3
-rwxr-xr-xEigen/src/Core/util/BlasUtil.h41
-rw-r--r--test/product_small.cpp11
5 files changed, 51 insertions, 14 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index bf7ef54b5..6906aa75d 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -228,8 +228,7 @@ template<> struct gemv_dense_selector<OnTheRight,ColMajor,true>
ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
// make sure Dest is a compile-time vector type (bug 1166)
typedef typename conditional<Dest::IsVectorAtCompileTime, Dest, typename Dest::ColXpr>::type ActualDest;
@@ -320,8 +319,7 @@ template<> struct gemv_dense_selector<OnTheRight,RowMajor,true>
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
- ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
- * RhsBlasTraits::extractScalarFactor(rhs);
+ ResScalar actualAlpha = combine_scalar_factors(alpha, lhs, rhs);
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 02b58438c..079189a10 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -441,8 +441,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
};
// FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto
// this is important for real*complex_mat
- Scalar actualAlpha = blas_traits<Lhs>::extractScalarFactor(lhs)
- * blas_traits<Rhs>::extractScalarFactor(rhs);
+ Scalar actualAlpha = combine_scalar_factors<Scalar>(lhs, rhs);
+
eval_dynamic_impl(dst,
blas_traits<Lhs>::extract(lhs).template conjugateIf<ConjLhs>(),
blas_traits<Rhs>::extract(rhs).template conjugateIf<ConjRhs>(),
diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h
index 0d55bdf9e..caa65fccc 100644
--- a/Eigen/src/Core/products/GeneralMatrixMatrix.h
+++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h
@@ -489,8 +489,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
- * RhsBlasTraits::extractScalarFactor(a_rhs);
+ Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index a90e57446..c5161022c 100755
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -618,6 +618,47 @@ template<typename T> const typename T::Scalar* extract_data(const T& m)
return extract_data_selector<T>::run(m);
}
+/**
+ * \c combine_scalar_factors extracts and multiplies factors from GEMM and GEMV products.
+ * There is a specialization for booleans
+ */
+template<typename ResScalar, typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static ResScalar run(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha * blas_traits<Lhs>::extractScalarFactor(lhs) * blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+template<typename Lhs, typename Rhs>
+struct combine_scalar_factors_impl<bool, Lhs, Rhs>
+{
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const Lhs& lhs, const Rhs& rhs)
+ {
+ return blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const bool& alpha, const Lhs& lhs, const Rhs& rhs)
+ {
+ return alpha && blas_traits<Lhs>::extractScalarFactor(lhs) && blas_traits<Rhs>::extractScalarFactor(rhs);
+ }
+};
+
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar& alpha, const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(alpha, lhs, rhs);
+}
+template<typename ResScalar, typename Lhs, typename Rhs>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const Lhs& lhs, const Rhs& rhs)
+{
+ return combine_scalar_factors_impl<ResScalar,Lhs,Rhs>::run(lhs, rhs);
+}
+
+
} // end namespace internal
} // end namespace Eigen
diff --git a/test/product_small.cpp b/test/product_small.cpp
index 93876dba4..1d6df6e58 100644
--- a/test/product_small.cpp
+++ b/test/product_small.cpp
@@ -56,18 +56,17 @@ test_lazy_single(int rows, int cols, int depth)
VERIFY_IS_APPROX(C+=A.lazyProduct(B), ref_prod(D,A,B));
}
-template<typename T>
-void test_dynamic_exact()
+void test_dynamic_bool()
{
int rows = internal::random<int>(1,64);
int cols = internal::random<int>(1,64);
int depth = internal::random<int>(1,65);
- typedef Matrix<T,Dynamic,Dynamic> MatrixX;
+ typedef Matrix<bool,Dynamic,Dynamic> MatrixX;
MatrixX A(rows,depth); A.setRandom();
MatrixX B(depth,cols); B.setRandom();
- MatrixX C(rows,cols); C.setRandom();
- MatrixX D(C);
+ MatrixX C(rows,cols); C.setRandom();
+ MatrixX D(C);
for(Index i=0;i<C.rows();++i)
for(Index j=0;j<C.cols();++j)
for(Index k=0;k<A.cols();++k)
@@ -317,7 +316,7 @@ EIGEN_DECLARE_TEST(product_small)
CALL_SUBTEST_6( bug_1311<3>() );
CALL_SUBTEST_6( bug_1311<5>() );
- CALL_SUBTEST_9( test_dynamic_exact<bool>() );
+ CALL_SUBTEST_9( test_dynamic_bool() );
}
CALL_SUBTEST_6( product_small_regressions<0>() );