diff options
author | Gael Guennebaud <g.gael@free.fr> | 2014-02-23 22:51:13 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2014-02-23 22:51:13 +0100 |
commit | c98881e1306c94c53ad3de77f9e9036d98dbcf2a (patch) | |
tree | 0f76a14db3be2fc62b363ef99b32917d3be9cf9d /Eigen/src/Core | |
parent | d67548f345d01c69d9dbba5869d8cc0159e96464 (diff) |
By-pass ProductBase for triangular and selfadjoint products and get rid of ProductBase
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/GeneralProduct.h | 1 | ||||
-rw-r--r-- | Eigen/src/Core/ProductBase.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 20 | ||||
-rw-r--r-- | Eigen/src/Core/SelfAdjointView.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixMatrix.h | 52 | ||||
-rw-r--r-- | Eigen/src/Core/products/SelfadjointMatrixVector.h | 104 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixMatrix.h | 53 | ||||
-rw-r--r-- | Eigen/src/Core/products/TriangularMatrixVector.h | 121 |
8 files changed, 301 insertions, 56 deletions
diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index 57d5d3c38..06aa05ee6 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -688,7 +688,6 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true> typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; typedef typename Dest::Scalar ResScalar; - typedef typename Dest::RealScalar RealScalar; typedef internal::blas_traits<Lhs> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h index f7cef9a9e..3b2246fd8 100644 --- a/Eigen/src/Core/ProductBase.h +++ b/Eigen/src/Core/ProductBase.h @@ -11,6 +11,8 @@ #define EIGEN_PRODUCTBASE_H namespace Eigen { + +#ifndef EIGEN_TEST_EVALUATORS /** \class ProductBase * \ingroup Core_Module @@ -174,8 +176,6 @@ class ProductBase : public MatrixBase<Derived> mutable PlainObject m_result; }; -#ifndef EIGEN_TEST_EVALUATORS - // here we need to overload the nested rule for products // such that the nested type is a const reference to a plain matrix namespace internal { diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 5f9cf69a2..186ae4a34 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -532,6 +532,10 @@ struct etor_product_packet_impl<ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode> /*************************************************************************** * Triangular products ***************************************************************************/ +template<int Mode, bool LhsIsTriangular, + typename Lhs, bool LhsIsVector, + typename Rhs, bool RhsIsVector> +struct triangular_product_impl; template<typename Lhs, typename Rhs, int ProductTag> struct generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag> @@ -542,8 +546,8 @@ struct generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag> template<typename Dest> static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - // TODO bypass TriangularProduct class - TriangularProduct<Lhs::Mode,true,typename Lhs::MatrixType,false,Rhs, Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha); + triangular_product_impl<Lhs::Mode,true,typename Lhs::MatrixType,false,Rhs, Rhs::IsVectorAtCompileTime> + ::run(dst, lhs.nestedExpression(), rhs, alpha); } }; @@ -576,8 +580,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,TriangularShape,ProductTag> template<typename Dest> static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - // TODO bypass TriangularProduct class - TriangularProduct<Rhs::Mode,false,Lhs,Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha); + triangular_product_impl<Rhs::Mode,false,Lhs,Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, false>::run(dst, lhs, rhs.nestedExpression(), alpha); } }; @@ -605,6 +608,9 @@ protected: /*************************************************************************** * SelfAdjoint products ***************************************************************************/ +template <typename Lhs, int LhsMode, bool LhsIsVector, + typename Rhs, int RhsMode, bool RhsIsVector> +struct selfadjoint_product_impl; template<typename Lhs, typename Rhs, int ProductTag> struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> @@ -615,8 +621,7 @@ struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> template<typename Dest> static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - // TODO bypass SelfadjointProductMatrix class - SelfadjointProductMatrix<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>(lhs.nestedExpression(),rhs).scaleAndAddTo(dst, alpha); + selfadjoint_product_impl<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>::run(dst, lhs.nestedExpression(), rhs, alpha); } }; @@ -649,8 +654,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> template<typename Dest> static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - // TODO bypass SelfadjointProductMatrix class - SelfadjointProductMatrix<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>(lhs,rhs.nestedExpression()).scaleAndAddTo(dst, alpha); + selfadjoint_product_impl<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>::run(dst, lhs, rhs.nestedExpression(), alpha); } }; diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index 2cc1815fd..f7f512cf4 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -45,9 +45,11 @@ struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType> }; } +#ifndef EIGEN_TEST_EVALUATORS template <typename Lhs, int LhsMode, bool LhsIsVector, typename Rhs, int RhsMode, bool RhsIsVector> struct SelfadjointProductMatrix; +#endif // FIXME could also be called SelfAdjointWrapper to be consistent with DiagonalWrapper ?? template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView diff --git a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h index 99cf9e0ae..f252aef85 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +++ b/Eigen/src/Core/products/SelfadjointMatrixMatrix.h @@ -381,6 +381,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f * Wrapper to product_selfadjoint_matrix ***************************************************************************/ +#ifndef EIGEN_TEST_EVALUATORS namespace internal { template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> > @@ -430,6 +431,57 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> ); } }; +#else // EIGEN_TEST_EVALUATORS +namespace internal { + +template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> +struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false> +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + typedef typename Product<Lhs,Rhs>::Index Index; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + + enum { + LhsIsUpper = (LhsMode&(Upper|Lower))==Upper, + LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint, + RhsIsUpper = (RhsMode&(Upper|Lower))==Upper, + RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint + }; + + template<typename Dest> + static void run(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha) + { + eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols()); + + typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); + + internal::product_selfadjoint_matrix<Scalar, Index, + EIGEN_LOGICAL_XOR(LhsIsUpper,internal::traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint, + NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)), + EIGEN_LOGICAL_XOR(RhsIsUpper,internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint, + NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)), + internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor> + ::run( + lhs.rows(), rhs.cols(), // sizes + &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info + &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info + &dst.coeffRef(0,0), dst.outerStride(), // result info + actualAlpha // alpha + ); + } +}; + +} // end namespace internal + +#endif } // end namespace Eigen diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h index f698f67f9..ddc07d535 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixVector.h +++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h @@ -168,6 +168,7 @@ EIGEN_DONT_INLINE void selfadjoint_matrix_vector_product<Scalar,Index,StorageOrd * Wrapper to product_selfadjoint_vector ***************************************************************************/ +#ifndef EIGEN_TEST_EVALUATORS namespace internal { template<typename Lhs, int LhsMode, typename Rhs> struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> > @@ -276,6 +277,109 @@ struct SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> } }; +#else // EIGEN_TEST_EVALUATORS + +namespace internal { + +template<typename Lhs, int LhsMode, typename Rhs> +struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,0,true> +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + typedef typename Product<Lhs,Rhs>::Index Index; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned; + + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + enum { LhsUpLo = LhsMode&(Upper|Lower) }; + + template<typename Dest> + static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha) + { + typedef typename Dest::Scalar ResScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; + + eigen_assert(dest.rows()==a_lhs.rows() && dest.cols()==a_rhs.cols()); + + typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); + + enum { + EvalToDest = (Dest::InnerStrideAtCompileTime==1), + UseRhs = (ActualRhsTypeCleaned::InnerStrideAtCompileTime==1) + }; + + internal::gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,!EvalToDest> static_dest; + internal::gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!UseRhs> static_rhs; + + ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), + EvalToDest ? dest.data() : static_dest.data()); + + ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,rhs.size(), + UseRhs ? const_cast<RhsScalar*>(rhs.data()) : static_rhs.data()); + + if(!EvalToDest) + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = dest.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + MappedDest(actualDestPtr, dest.size()) = dest; + } + + if(!UseRhs) + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = rhs.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, rhs.size()) = rhs; + } + + + internal::selfadjoint_matrix_vector_product<Scalar, Index, (internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, + int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate), bool(RhsBlasTraits::NeedToConjugate)>::run + ( + lhs.rows(), // size + &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info + actualRhsPtr, 1, // rhs info + actualDestPtr, // result info + actualAlpha // scale factor + ); + + if(!EvalToDest) + dest = MappedDest(actualDestPtr, dest.size()); + } +}; + +template<typename Lhs, typename Rhs, int RhsMode> +struct selfadjoint_product_impl<Lhs,0,true,Rhs,RhsMode,false> +{ + typedef typename Product<Lhs,Rhs>::Scalar Scalar; + enum { RhsUpLo = RhsMode&(Upper|Lower) }; + + template<typename Dest> + static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha) + { + // let's simply transpose the product + Transpose<Dest> destT(dest); + selfadjoint_product_impl<Transpose<const Rhs>, int(RhsUpLo)==Upper ? Lower : Upper, false, + Transpose<const Lhs>, 0, true>::run(destT, a_rhs.transpose(), a_lhs.transpose(), alpha); + } +}; + +} // end namespace internal + +#endif // EIGEN_TEST_EVALUATORS + } // end namespace Eigen #endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h index 8110507b5..e654b45b1 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h @@ -372,7 +372,6 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false, /*************************************************************************** * Wrapper to product_triangular_matrix_matrix ***************************************************************************/ - template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> > : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> > @@ -380,6 +379,7 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> > } // end namespace internal +#ifndef EIGEN_TEST_EVALUATORS template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> : public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs > @@ -421,6 +421,57 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> ); } }; +#else // EIGEN_TEST_EVALUATORS +namespace internal { +template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> +struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false> +{ + template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha) + { + typedef typename Dest::Index Index; + typedef typename Dest::Scalar Scalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); + + typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, + Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType; + + enum { IsLower = (Mode&Lower) == Lower }; + Index stripedRows = ((!LhsIsTriangular) || (IsLower)) ? lhs.rows() : (std::min)(lhs.rows(),lhs.cols()); + Index stripedCols = ((LhsIsTriangular) || (!IsLower)) ? rhs.cols() : (std::min)(rhs.cols(),rhs.rows()); + Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows())) + : ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols())); + + BlockingType blocking(stripedRows, stripedCols, stripedDepth); + + internal::product_triangular_matrix_matrix<Scalar, Index, + Mode, LhsIsTriangular, + (internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, + (internal::traits<ActualRhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, + (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor> + ::run( + stripedRows, stripedCols, stripedDepth, // sizes + &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info + &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info + &dst.coeffRef(0,0), dst.outerStride(), // result info + actualAlpha, blocking + ); + } +}; + +} // end namespace internal +#endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index 6117d5a82..eed7f4258 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -168,11 +168,12 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> > {}; -template<int StorageOrder> +template<int Mode,int StorageOrder> struct trmv_selector; } // end namespace internal +#ifndef EIGEN_TEST_EVALUATORS template<int Mode, typename Lhs, typename Rhs> struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs > @@ -185,7 +186,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> { eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); - internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha); + internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(m_lhs, m_rhs, dst, alpha); } }; @@ -201,39 +202,71 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false> { eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); - typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose; Transpose<Dest> dstT(dst); - internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run( - TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha); + internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower), + (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor> + ::run(m_rhs.transpose(),m_lhs.transpose(), dstT, alpha); } }; +#else // EIGEN_TEST_EVALUATORS +namespace internal { + +template<int Mode, typename Lhs, typename Rhs> +struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true> +{ + template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha) + { + eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols()); + + internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha); + } +}; + +template<int Mode, typename Lhs, typename Rhs> +struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false> +{ + template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha) + { + eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols()); + + Transpose<Dest> dstT(dst); + internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower), + (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor> + ::run(rhs.transpose(),lhs.transpose(), dstT, alpha); + } +}; + +} // end namespace internal +#endif // EIGEN_TEST_EVALUATORS + namespace internal { // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. -template<> struct trmv_selector<ColMajor> +template<int Mode> struct trmv_selector<Mode,ColMajor> { - template<int Mode, typename Lhs, typename Rhs, typename Dest> - static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha) + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) { - typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; - typedef typename ProductType::Index Index; - typedef typename ProductType::LhsScalar LhsScalar; - typedef typename ProductType::RhsScalar RhsScalar; - typedef typename ProductType::Scalar ResScalar; - typedef typename ProductType::RealScalar RealScalar; - typedef typename ProductType::ActualLhsType ActualLhsType; - typedef typename ProductType::ActualRhsType ActualRhsType; - typedef typename ProductType::LhsBlasTraits LhsBlasTraits; - typedef typename ProductType::RhsBlasTraits RhsBlasTraits; + typedef typename Dest::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef typename Dest::Scalar ResScalar; + typedef typename Dest::RealScalar RealScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; - typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs); + typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs); - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) - * RhsBlasTraits::extractScalarFactor(prod.rhs()); + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) + * RhsBlasTraits::extractScalarFactor(rhs); enum { // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 @@ -288,33 +321,33 @@ template<> struct trmv_selector<ColMajor> } }; -template<> struct trmv_selector<RowMajor> +template<int Mode> struct trmv_selector<Mode,RowMajor> { - template<int Mode, typename Lhs, typename Rhs, typename Dest> - static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha) + template<typename Lhs, typename Rhs, typename Dest> + static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha) { - typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; - typedef typename ProductType::LhsScalar LhsScalar; - typedef typename ProductType::RhsScalar RhsScalar; - typedef typename ProductType::Scalar ResScalar; - typedef typename ProductType::Index Index; - typedef typename ProductType::ActualLhsType ActualLhsType; - typedef typename ProductType::ActualRhsType ActualRhsType; - typedef typename ProductType::_ActualRhsType _ActualRhsType; - typedef typename ProductType::LhsBlasTraits LhsBlasTraits; - typedef typename ProductType::RhsBlasTraits RhsBlasTraits; - - typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); - - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) - * RhsBlasTraits::extractScalarFactor(prod.rhs()); + typedef typename Dest::Index Index; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + typedef typename Dest::Scalar ResScalar; + + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned; + + typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs); + typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) + * RhsBlasTraits::extractScalarFactor(rhs); enum { - DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1 + DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1 }; - gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; + gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(), DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data()); @@ -325,7 +358,7 @@ template<> struct trmv_selector<RowMajor> int size = actualRhs.size(); EIGEN_DENSE_STORAGE_CTOR_PLUGIN #endif - Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; + Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; } internal::triangular_matrix_vector_product |