diff options
Diffstat (limited to 'Eigen/src/Core/ProductEvaluators.h')
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 146 |
1 files changed, 143 insertions, 3 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 2279ec33b..d991ff8b5 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -309,9 +309,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> }; // This specialization enforces the use of a coefficient-based evaluation strategy -// template<typename Lhs, typename Rhs> -// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode> -// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {}; +template<typename Lhs, typename Rhs> +struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode> + : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {}; // Case 2: Evaluate coeff by coeff // @@ -764,6 +764,146 @@ protected: PlainObject m_result; }; +/*************************************************************************** +* Diagonal products +***************************************************************************/ + +template<typename MatrixType, typename DiagonalType, typename Derived> +struct diagonal_product_evaluator_base + : evaluator_base<Derived> +{ + typedef typename MatrixType::Index Index; + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::PacketScalar PacketScalar; +public: + diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag) + : m_diagImpl(diag), m_matImpl(mat) + { + } + + EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const + { + return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx); + } + +protected: + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const + { + return internal::pmul(m_matImpl.template packet<LoadMode>(row, col), + internal::pset1<PacketScalar>(m_diagImpl.coeff(id))); + } + + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const + { + enum { + InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime, + DiagonalPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned) + }; + return internal::pmul(m_matImpl.template packet<LoadMode>(row, col), + m_diagImpl.template packet<DiagonalPacketLoadMode>(id)); + } + + typename evaluator<DiagonalType>::nestedType m_diagImpl; + typename evaluator<MatrixType>::nestedType m_matImpl; +}; + +// diagonal * dense +template<typename Lhs, typename Rhs, int ProductKind, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > +{ + typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base; + using Base::m_diagImpl; + using Base::m_matImpl; + using Base::coeff; + using Base::packet_impl; + typedef typename Base::Scalar Scalar; + typedef typename Base::Index Index; + typedef typename Base::PacketScalar PacketScalar; + + typedef Product<Lhs, Rhs, ProductKind> XprType; + typedef typename XprType::PlainObject PlainObject; + + product_evaluator(const XprType& xpr) + : Base(xpr.rhs(), xpr.lhs().diagonal()) + { + } + + EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const + { + return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col); + } + + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const + { + enum { + StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor + }; + return this->template packet_impl<LoadMode>(row,col, row, + typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type()); + } + + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const + { + enum { + StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor + }; + return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx); + } + +}; + +// dense * diagonal +template<typename Lhs, typename Rhs, int ProductKind, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar> + : diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > +{ + typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base; + using Base::m_diagImpl; + using Base::m_matImpl; + using Base::coeff; + using Base::packet_impl; + typedef typename Base::Scalar Scalar; + typedef typename Base::Index Index; + typedef typename Base::PacketScalar PacketScalar; + + typedef Product<Lhs, Rhs, ProductKind> XprType; + typedef typename XprType::PlainObject PlainObject; + + product_evaluator(const XprType& xpr) + : Base(xpr.lhs(), xpr.rhs().diagonal()) + { + } + + EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const + { + return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col); + } + + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const + { + enum { + StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor + }; + return this->template packet_impl<LoadMode>(row,col, col, + typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type()); + } + + template<int LoadMode> + EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const + { + enum { + StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor + }; + return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx); + } + +}; } // end namespace internal |