From fcf4457b781831b51ac70d4141b29b062d29fdf3 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 31 May 2008 21:35:11 +0000 Subject: added optimized matrix times diagonal matrix product via Diagonal flag shortcut. --- Eigen/src/Array/CwiseOperators.h | 8 ++++++ Eigen/src/Array/Functors.h | 13 +++++++++ Eigen/src/Core/DiagonalMatrix.h | 2 +- Eigen/src/Core/MatrixBase.h | 1 + Eigen/src/Core/Product.h | 44 +++++++++++++++++++++---------- Eigen/src/Core/util/Constants.h | 1 + Eigen/src/Core/util/ForwardDeclarations.h | 1 + Eigen/src/Core/util/Meta.h | 3 +-- 8 files changed, 56 insertions(+), 17 deletions(-) (limited to 'Eigen/src') diff --git a/Eigen/src/Array/CwiseOperators.h b/Eigen/src/Array/CwiseOperators.h index 82a18d11c..f9f9e1267 100644 --- a/Eigen/src/Array/CwiseOperators.h +++ b/Eigen/src/Array/CwiseOperators.h @@ -76,6 +76,14 @@ MatrixBase::cwisePow(const Scalar& exponent) const (derived(), ei_scalar_pow_op(exponent)); } +/** \returns an expression of the coefficient-wise reciprocal of *this. 
*/ +template +inline const CwiseUnaryOp::Scalar>, Derived> +MatrixBase::cwiseInverse() const +{ + return derived(); +} + // -- binary operators -- /** \returns an expression of the coefficient-wise \< operator of *this and \a other diff --git a/Eigen/src/Array/Functors.h b/Eigen/src/Array/Functors.h index 805f2dcad..3169106d8 100644 --- a/Eigen/src/Array/Functors.h +++ b/Eigen/src/Array/Functors.h @@ -100,6 +100,19 @@ template struct ei_functor_traits > { enum { Cost = 5 * NumTraits::MulCost, IsVectorizable = false }; }; +/** \internal + * \brief Template functor to compute the reciprocal of a scalar + * + * \sa class CwiseUnaryOp, MatrixBase::cwiseInverse + */ +template +struct ei_scalar_inverse_op { + inline Scalar operator() (const Scalar& a) const { return Scalar(1)/a; } +}; +template +struct ei_functor_traits > +{ enum { Cost = NumTraits::MulCost, IsVectorizable = false }; }; + // default ei_functor_traits for STL functors: template diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h index 0581c669c..fa343e7d7 100644 --- a/Eigen/src/Core/DiagonalMatrix.h +++ b/Eigen/src/Core/DiagonalMatrix.h @@ -49,7 +49,7 @@ struct ei_traits > ColsAtCompileTime = CoeffsVectorType::SizeAtCompileTime, MaxRowsAtCompileTime = CoeffsVectorType::MaxSizeAtCompileTime, MaxColsAtCompileTime = CoeffsVectorType::MaxSizeAtCompileTime, - Flags = _CoeffsVectorTypeTypeNested::Flags & HereditaryBits, + Flags = _CoeffsVectorTypeTypeNested::Flags & HereditaryBits | Diagonal, CoeffReadCost = _CoeffsVectorTypeTypeNested::CoeffReadCost }; }; diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index fc1d524bd..ab35b94a6 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -544,6 +544,7 @@ template class MatrixBase : public ArrayBase const CwiseUnaryOp::Scalar>, Derived> cwiseSin() const; const CwiseUnaryOp::Scalar>, Derived> cwisePow(const Scalar& exponent) const; + const CwiseUnaryOp::Scalar>, Derived> cwiseInverse() const; 
template const CwiseBinaryOp::Scalar>, Derived, OtherDerived> diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 1def64777..7643ac610 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -147,6 +147,7 @@ template struct ei_product_eval_mode enum{ value = Lhs::MaxRowsAtCompileTime >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD && Rhs::MaxColsAtCompileTime >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD && Lhs::MaxColsAtCompileTime >= EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD + && (Rhs::Flags&Diagonal)!=Diagonal /* mask first: '!=' binds tighter than '&', so the unparenthesized form was always 0 */ ? CacheFriendlyProduct : NormalProduct }; }; @@ -259,25 +260,40 @@ template class Product : ei_no_assignm const Scalar _coeff(int row, int col) const { - Scalar res; - const bool unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT; - ei_product_impl - ::run(row, col, m_lhs, m_rhs, res); - return res; + if ((Rhs::Flags&Diagonal)==Diagonal) + { + return m_lhs.coeff(row, col) * m_rhs.coeff(col, col); + } + else + { + Scalar res; + const bool unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT; + ei_product_impl + ::run(row, col, m_lhs, m_rhs, res); + return res; + } } template const PacketScalar _packetCoeff(int row, int col) const { - const bool unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT; - PacketScalar res; - ei_packet_product_impl - ::run(row, col, m_lhs, m_rhs, res); - return res; + if ((Rhs::Flags&Diagonal)==Diagonal) + { + assert((_LhsNested::Flags&RowMajorBit)==0); /* requires RowMajorBit clear on the lhs; mask first, '==' binds tighter than '&' */ + return ei_pmul(m_lhs.template packetCoeff(row, col), ei_pset1(m_rhs.coeff(col, col))); + } + else + { + const bool unroll = CoeffReadCost <= EIGEN_UNROLLING_LIMIT; + PacketScalar res; + ei_packet_product_impl + ::run(row, col, m_lhs, m_rhs, res); + return res; + } } template diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 909921f35..e17563c9a 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -66,6 +66,7 @@ const unsigned int SelfAdjoint = SelfAdjointBit; // additional possible values for the Mode parameter 
of extract() const unsigned int UnitUpper = UpperTriangularBit | UnitDiagBit; const unsigned int UnitLower = LowerTriangularBit | UnitDiagBit; +const unsigned int Diagonal = Upper | Lower; diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index cc759ddea..9fa3718d3 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -67,6 +67,7 @@ template struct ei_scalar_log_op; template struct ei_scalar_cos_op; template struct ei_scalar_sin_op; template struct ei_scalar_pow_op; +template struct ei_scalar_inverse_op; template struct ei_scalar_cast_op; template struct ei_scalar_multiple_op; template struct ei_scalar_quotient1_op; diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h index f1939f59e..1ce7f3cfe 100644 --- a/Eigen/src/Core/util/Meta.h +++ b/Eigen/src/Core/util/Meta.h @@ -212,8 +212,7 @@ template struct ei_nested template struct ei_are_flags_consistent { - enum { ret = !( (Flags&UnitDiagBit && Flags&ZeroDiagBit) - || (Flags&UpperTriangularBit && Flags&LowerTriangularBit) ) + enum { ret = !( (Flags&UnitDiagBit && Flags&ZeroDiagBit) ) }; }; -- cgit v1.2.3