diff options
author | Gael Guennebaud <g.gael@free.fr> | 2013-12-14 22:53:47 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2013-12-14 22:53:47 +0100 |
commit | d357bbd9c06f4b6088de0a8e47b3e56fdd0b99b3 (patch) | |
tree | 344a2618a4d283f08816498f6da210e90bf4bb54 /Eigen/src/Core/ProductEvaluators.h | |
parent | 27c068e9d6230398b74a1c7b7146d7842c509de7 (diff) |
Fix a few regression regarding temporaries and products
Diffstat (limited to 'Eigen/src/Core/ProductEvaluators.h')
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 72 |
1 files changed, 68 insertions, 4 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index f0eb57d67..46048882b 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -19,7 +19,7 @@ namespace internal { // Like more general binary expressions, products need their own evaluator: template< typename T, - int ProductTag = internal::product_tag<typename T::Lhs,typename T::Rhs>::ret, + int ProductTag = internal::product_type<typename T::Lhs,typename T::Rhs>::ret, typename LhsShape = typename evaluator_traits<typename T::Lhs>::Shape, typename RhsShape = typename evaluator_traits<typename T::Rhs>::Shape, typename LhsScalar = typename T::Lhs::Scalar, @@ -38,7 +38,43 @@ struct evaluator<Product<Lhs, Rhs, Options> > evaluator(const XprType& xpr) : Base(xpr) {} }; + +// Catch scalar * ( A * B ) and transform it to (A*scalar) * B +// TODO we should apply that rule if that's really helpful +template<typename Lhs, typename Rhs, typename Scalar> +struct evaluator<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > > + : public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> > +{ + typedef CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > XprType; + typedef evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> > Base; + + typedef evaluator type; + typedef evaluator nestedType; + + evaluator(const XprType& xpr) + : Base(xpr.functor().m_other * xpr.nestedExpression().lhs() * xpr.nestedExpression().rhs()) + {} +}; + + +template<typename Lhs, typename Rhs, int DiagIndex> +struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> > + : public evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> > +{ + typedef Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> XprType; + typedef evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> > Base; + typedef evaluator type; + typedef evaluator nestedType; +// + evaluator(const XprType& xpr) + : Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>( + Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()), + xpr.index() )) + {} +}; + + // Helper class to perform a matrix product with the destination at hand. // Depending on the sizes of the factors, there are different evaluation strategies // as controlled by internal::product_type. @@ -108,6 +144,23 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::sub_ass } }; + +// Dense ?= scalar * Product +// TODO we should apply that rule if that's really helpful +// for instance, this is not good for inner products +template< typename DstXprType, typename Lhs, typename Rhs, typename AssignFunc, typename Scalar, typename ScalarBis> +struct Assignment<DstXprType, CwiseUnaryOp<internal::scalar_multiple_op<ScalarBis>, + const Product<Lhs,Rhs,DefaultProduct> >, AssignFunc, Dense2Dense, Scalar> +{ + typedef CwiseUnaryOp<internal::scalar_multiple_op<ScalarBis>, + const Product<Lhs,Rhs,DefaultProduct> > SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const AssignFunc& func) + { + call_assignment(dst.noalias(), (src.functor().m_other * src.nestedExpression().lhs()) * src.nestedExpression().rhs(), func); + } +}; + + template<typename Lhs, typename Rhs> struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct> { @@ -255,9 +308,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> }; // This specialization enforces the use of a coefficient-based evaluation strategy -template<typename Lhs, typename Rhs> -struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode> - : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {}; +// template<typename Lhs, typename Rhs> +// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode> +// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {}; // Case 2: Evaluate coeff by coeff // @@ -347,6 +400,17 @@ protected: Index m_innerDim; }; +template<typename Lhs, typename Rhs> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, LazyCoeffBasedProductMode, DenseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar > + : product_evaluator<Product<Lhs, Rhs, LazyProduct>, CoeffBasedProductMode, DenseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar > +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef Product<Lhs, Rhs, LazyProduct> BaseProduct; + typedef product_evaluator<BaseProduct, CoeffBasedProductMode, DenseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar > Base; + product_evaluator(const XprType& xpr) + : Base(BaseProduct(xpr.lhs(),xpr.rhs())) + {} +}; /*************************************************************************** * Normal product .coeff() implementation (with meta-unrolling) |