diff options
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 36 | ||||
-rw-r--r-- | test/product.h | 8 | ||||
-rw-r--r-- | test/product_notemporary.cpp | 3 |
3 files changed, 28 insertions, 19 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 955668bef..a64bda394 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -194,7 +194,6 @@ struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_product_op<ScalarBi //---------------------------------------- // Catch "Dense ?= xpr + Product<>" expression to save one temporary // FIXME we could probably enable these rules for any product, i.e., not only Dense and DefaultProduct -// TODO enable it for "Dense ?= xpr - Product<>" as well. template<typename OtherXpr, typename Lhs, typename Rhs> struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_sum_op<typename OtherXpr::Scalar,typename Product<Lhs,Rhs,DefaultProduct>::Scalar>, const OtherXpr, @@ -203,10 +202,9 @@ struct evaluator_assume_aliasing<CwiseBinaryOp<internal::scalar_sum_op<typename }; template<typename DstXprType, typename OtherXpr, typename ProductType, typename Func1, typename Func2> -struct assignment_from_xpr_plus_product +struct assignment_from_xpr_op_product { - typedef CwiseBinaryOp<internal::scalar_sum_op<typename OtherXpr::Scalar,typename ProductType::Scalar>, const OtherXpr, const ProductType> SrcXprType; - template<typename InitialFunc> + template<typename SrcXprType, typename InitialFunc> static EIGEN_STRONG_INLINE void run(DstXprType &dst, const SrcXprType &src, const InitialFunc& /*func*/) { @@ -215,21 +213,21 @@ struct assignment_from_xpr_plus_product } }; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr, - const Product<Lhs,Rhs,DefaultProduct> >, internal::assign_op<DstScalar,SrcScalar>, Dense2Dense> - : assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::assign_op<DstScalar,OtherScalar>, internal::add_assign_op<DstScalar,ProdScalar> > -{}; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr, - const Product<Lhs,Rhs,DefaultProduct> >, internal::add_assign_op<DstScalar,SrcScalar>, Dense2Dense> - : assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::add_assign_op<DstScalar,OtherScalar>, internal::add_assign_op<DstScalar,ProdScalar> > -{}; -template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> -struct Assignment<DstXprType, CwiseBinaryOp<internal::scalar_sum_op<OtherScalar,ProdScalar>, const OtherXpr, - const Product<Lhs,Rhs,DefaultProduct> >, internal::sub_assign_op<DstScalar,SrcScalar>, Dense2Dense> - : assignment_from_xpr_plus_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::sub_assign_op<DstScalar,OtherScalar>, internal::sub_assign_op<DstScalar,ProdScalar> > -{}; +#define EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(ASSIGN_OP,BINOP,ASSIGN_OP2) \ + template< typename DstXprType, typename OtherXpr, typename Lhs, typename Rhs, typename DstScalar, typename SrcScalar, typename OtherScalar,typename ProdScalar> \ + struct Assignment<DstXprType, CwiseBinaryOp<internal::BINOP<OtherScalar,ProdScalar>, const OtherXpr, \ + const Product<Lhs,Rhs,DefaultProduct> >, internal::ASSIGN_OP<DstScalar,SrcScalar>, Dense2Dense> \ + : assignment_from_xpr_op_product<DstXprType, OtherXpr, Product<Lhs,Rhs,DefaultProduct>, internal::ASSIGN_OP<DstScalar,OtherScalar>, internal::ASSIGN_OP2<DstScalar,ProdScalar> > \ + {} + +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_sum_op,add_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_sum_op,add_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_sum_op,sub_assign_op); + +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(assign_op, scalar_difference_op,sub_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(add_assign_op,scalar_difference_op,sub_assign_op); +EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op,scalar_difference_op,add_assign_op); + //---------------------------------------- template<typename Lhs, typename Rhs> diff --git a/test/product.h b/test/product.h index 27976a4ae..cabfc0b03 100644 --- a/test/product.h +++ b/test/product.h @@ -119,6 +119,14 @@ template<typename MatrixType> void product(const MatrixType& m) res.noalias() -= square + m1 * m2.transpose(); VERIFY_IS_APPROX(res, square + m1 * m2.transpose()); + // test d ?= a-b*c rules + res.noalias() = square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, square - m1 * m2.transpose()); + res.noalias() += square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, 2*(square - m1 * m2.transpose())); + res.noalias() -= square - m1 * m2.transpose(); + VERIFY_IS_APPROX(res, square - m1 * m2.transpose()); + tm1 = m1; VERIFY_IS_APPROX(tm1.transpose() * v1, m1.transpose() * v1); diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index 5a3f3a01a..2bb19a681 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -56,6 +56,9 @@ template<typename MatrixType> void product_notemporary(const MatrixType& m) VERIFY_EVALUATION_COUNT( m3.noalias() = m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() += m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 + m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() = m3 - m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() += m3 - m1 * m2.transpose(), 0); + VERIFY_EVALUATION_COUNT( m3.noalias() -= m3 - m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * m2.adjoint(), 0); VERIFY_EVALUATION_COUNT( m3.noalias() = s1 * m1 * s2 * (m1*s3+m2*s2).adjoint(), 1); |