aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/product_notemporary.cpp
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2019-02-18 11:47:54 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2019-02-18 11:47:54 +0100
commit512b74aaa19fa12a05774dd30205d2c97e8bdef9 (patch)
treeefadb2022fb2291c4b733a7c7f4670dce6b01ba3 /test/product_notemporary.cpp
parentec032ac03b90dc6c58680a4dc858133e9a72fd1f (diff)
GEMM: catch all scalar-multiple variants when falling back to a coeff-based product.
Before, only s*A*B was caught, which was inconsistent with GEMM, sub-optimal, and could even lead to compilation errors (https://stackoverflow.com/questions/54738495).
Diffstat (limited to 'test/product_notemporary.cpp')
-rw-r--r--test/product_notemporary.cpp38
1 files changed, 38 insertions, 0 deletions
diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp
index dffb07608..7f169e6ae 100644
--- a/test/product_notemporary.cpp
+++ b/test/product_notemporary.cpp
@@ -11,6 +11,35 @@
#include "main.h"
+template<typename Dst, typename Lhs, typename Rhs> // Dst: destination type; Lhs/Rhs: types of the two product operands
+void check_scalar_multiple3(Dst &dst, const Lhs& A, const Rhs& B) // checks =, += and -= of the product A*B into dst
+{
+ VERIFY_EVALUATION_COUNT( (dst.noalias() = A * B), 0); // assignment: the expected evaluation count is 0
+ VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() ); // reference value built from explicitly evaluated operands
+ VERIFY_EVALUATION_COUNT( (dst.noalias() += A * B), 0); // accumulation: again an expected count of 0
+ VERIFY_IS_APPROX( dst, 2*(A.eval() * B.eval()).eval() ); // dst already held A*B, so it should now be twice the product
+ VERIFY_EVALUATION_COUNT( (dst.noalias() -= A * B), 0); // subtraction: again an expected count of 0
+ VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() ); // removing one product should bring dst back to A*B
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S2> // S2: type of the scalar applied to the right operand
+void check_scalar_multiple2(Dst &dst, const Lhs& A, const Rhs& B, S2 s2) // runs check_scalar_multiple3 for four variants of B, with A unchanged
+{
+ CALL_SUBTEST( check_scalar_multiple3(dst, A, B) ); // right operand as given
+ CALL_SUBTEST( check_scalar_multiple3(dst, A, -B) ); // negated right operand
+ CALL_SUBTEST( check_scalar_multiple3(dst, A, s2*B) ); // scalar multiplied on the left of B
+ CALL_SUBTEST( check_scalar_multiple3(dst, A, B*s2) ); // scalar multiplied on the right of B
+}
+
+template<typename Dst, typename Lhs, typename Rhs, typename S1, typename S2> // S1/S2: types of the scalars applied to the left/right operand
+void check_scalar_multiple1(Dst &dst, const Lhs& A, const Rhs& B, S1 s1, S2 s2) // entry point: four variants of A, each combined with the four variants of B
+{
+ CALL_SUBTEST( check_scalar_multiple2(dst, A, B, s2) ); // left operand as given
+ CALL_SUBTEST( check_scalar_multiple2(dst, -A, B, s2) ); // negated left operand
+ CALL_SUBTEST( check_scalar_multiple2(dst, s1*A, B, s2) ); // scalar multiplied on the left of A
+ CALL_SUBTEST( check_scalar_multiple2(dst, A*s1, B, s2) ); // scalar multiplied on the right of A
+}
+
template<typename MatrixType> void product_notemporary(const MatrixType& m)
{
/* This test checks the number of temporaries created
@@ -148,6 +177,15 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m)
// Check nested products
VERIFY_EVALUATION_COUNT( cvres.noalias() = m1.adjoint() * m1 * cv1, 1 );
VERIFY_EVALUATION_COUNT( rvres.noalias() = rv1 * (m1 * m2.adjoint()), 1 );
+
+ // exhaustively check all scalar multiple combinations:
+ {
+ // Generic path:
+ check_scalar_multiple1(m3, m1, m2, s1, s2);
+ // Force fall back to coeff-based:
+ typename ColMajorMatrixType::BlockXpr m3_blck = m3.block(r0,r0,1,1);
+ check_scalar_multiple1(m3_blck, m1.block(r0,c0,1,1), m2.block(c0,r0,1,1), s1, s2);
+ }
}
EIGEN_DECLARE_TEST(product_notemporary)