From 512b74aaa19fa12a05774dd30205d2c97e8bdef9 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 18 Feb 2019 11:47:54 +0100 Subject: GEMM: catch all scalar-multiple variants when falling-back to a coeff-based product. Before only s*A*B was caught which was both inconsistent with GEMM, sub-optimal, and could even lead to compilation-errors (https://stackoverflow.com/questions/54738495). --- test/product_notemporary.cpp | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) (limited to 'test/product_notemporary.cpp') diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index dffb07608..7f169e6ae 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -11,6 +11,35 @@ #include "main.h" +template +void check_scalar_multiple3(Dst &dst, const Lhs& A, const Rhs& B) +{ + VERIFY_EVALUATION_COUNT( (dst.noalias() = A * B), 0); + VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() ); + VERIFY_EVALUATION_COUNT( (dst.noalias() += A * B), 0); + VERIFY_IS_APPROX( dst, 2*(A.eval() * B.eval()).eval() ); + VERIFY_EVALUATION_COUNT( (dst.noalias() -= A * B), 0); + VERIFY_IS_APPROX( dst, (A.eval() * B.eval()).eval() ); +} + +template +void check_scalar_multiple2(Dst &dst, const Lhs& A, const Rhs& B, S2 s2) +{ + CALL_SUBTEST( check_scalar_multiple3(dst, A, B) ); + CALL_SUBTEST( check_scalar_multiple3(dst, A, -B) ); + CALL_SUBTEST( check_scalar_multiple3(dst, A, s2*B) ); + CALL_SUBTEST( check_scalar_multiple3(dst, A, B*s2) ); +} + +template +void check_scalar_multiple1(Dst &dst, const Lhs& A, const Rhs& B, S1 s1, S2 s2) +{ + CALL_SUBTEST( check_scalar_multiple2(dst, A, B, s2) ); + CALL_SUBTEST( check_scalar_multiple2(dst, -A, B, s2) ); + CALL_SUBTEST( check_scalar_multiple2(dst, s1*A, B, s2) ); + CALL_SUBTEST( check_scalar_multiple2(dst, A*s1, B, s2) ); +} + template void product_notemporary(const MatrixType& m) { /* This test checks the number of temporaries created @@ -148,6 +177,15 @@ template void product_notemporary(const MatrixType& m) // Check nested products VERIFY_EVALUATION_COUNT( cvres.noalias() = m1.adjoint() * m1 * cv1, 1 ); VERIFY_EVALUATION_COUNT( rvres.noalias() = rv1 * (m1 * m2.adjoint()), 1 ); + + // exhaustively check all scalar multiple combinations: + { + // Generic path: + check_scalar_multiple1(m3, m1, m2, s1, s2); + // Force fall back to coeff-based: + typename ColMajorMatrixType::BlockXpr m3_blck = m3.block(r0,r0,1,1); + check_scalar_multiple1(m3_blck, m1.block(r0,c0,1,1), m2.block(c0,r0,1,1), s1, s2); + } } EIGEN_DECLARE_TEST(product_notemporary) -- cgit v1.2.3