aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2011-02-01 11:38:46 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2011-02-01 11:38:46 +0100
commitc60818fca8ed58a272fab9f3f62024e04eac1a1c (patch)
tree55cf1d692557b7fbd682b33378cb0cae4444bb8a
parent0fdd01fe247ce40add4a6e45e817246fdf99ba5d (diff)
fix trmv regarding strided vectors and static allocation of temporaries
-rw-r--r--Eigen/src/Core/ProductBase.h2
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h186
-rw-r--r--test/nomalloc.cpp32
3 files changed, 194 insertions, 26 deletions
diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h
index 1f2b373cd..287ea554f 100644
--- a/Eigen/src/Core/ProductBase.h
+++ b/Eigen/src/Core/ProductBase.h
@@ -84,12 +84,14 @@ class ProductBase : public MatrixBase<Derived>
typedef internal::blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType;
+ typedef typename internal::traits<Lhs>::Scalar LhsScalar;
typedef typename Rhs::Nested RhsNested;
typedef typename internal::remove_all<RhsNested>::type _RhsNested;
typedef internal::blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType;
+ typedef typename internal::traits<Rhs>::Scalar RhsScalar;
// Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once
typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType;
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index c1f64dcea..23aa52ade 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -152,6 +152,10 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
+
+template<int StorageOrder>
+struct trmv_selector;
+
} // end namespace internal
template<int Mode, typename Lhs, typename Rhs>
@@ -165,20 +169,8 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
-
- const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
- const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
-
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
- * RhsBlasTraits::extractScalarFactor(m_rhs);
-
- internal::product_triangular_matrix_vector
- <Index,Mode,
- typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
- typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
- (int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
- ::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(),
- dst.data(),dst.innerStride(),actualAlpha);
+
+ internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
}
};
@@ -192,23 +184,167 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
-
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
+
+ typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
+ Transpose<Dest> dstT(dst);
+ internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
+ TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
+ }
+};
+
+namespace internal {
+
+// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
+
+template<> struct trmv_selector<ColMajor>
+{
+ template<int Mode, typename Lhs, typename Rhs, typename Dest>
+ static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
+ {
+ typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
+ typedef typename ProductType::Index Index;
+ typedef typename ProductType::LhsScalar LhsScalar;
+ typedef typename ProductType::RhsScalar RhsScalar;
+ typedef typename ProductType::Scalar ResScalar;
+ typedef typename ProductType::RealScalar RealScalar;
+ typedef typename ProductType::ActualLhsType ActualLhsType;
+ typedef typename ProductType::ActualRhsType ActualRhsType;
+ typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
+ typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
+ typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
+
+ const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
+ const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
+
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
+ * RhsBlasTraits::extractScalarFactor(prod.rhs());
- const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
- const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
+ enum {
+ // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
+ // on, the other hand it is good for the cache to pack the vector anyways...
+ EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
+ ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
+ MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
+ };
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
- * RhsBlasTraits::extractScalarFactor(m_rhs);
+ gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
+ bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
+ bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
+
+ RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
+
+ ResScalar* actualDestPtr;
+ bool freeDestPtr = false;
+ if (evalToDest)
+ {
+ actualDestPtr = dest.data();
+ }
+ else
+ {
+ #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ int size = dest.size();
+ EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ #endif
+ if((actualDestPtr = static_dest.data())==0)
+ {
+ freeDestPtr = true;
+ actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size());
+ }
+ if(!alphaIsCompatible)
+ {
+ MappedDest(actualDestPtr, dest.size()).setZero();
+ compatibleAlpha = RhsScalar(1);
+ }
+ else
+ MappedDest(actualDestPtr, dest.size()) = dest;
+ }
+
internal::product_triangular_matrix_vector
- <Index,(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),
- typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
- typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
- (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
- ::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(),
- dst.data(),dst.innerStride(),actualAlpha);
+ <Index,Mode,
+ LhsScalar, LhsBlasTraits::NeedToConjugate,
+ RhsScalar, RhsBlasTraits::NeedToConjugate,
+ ColMajor>
+ ::run(actualLhs.rows(),actualLhs.cols(),
+ actualLhs.data(),actualLhs.outerStride(),
+ actualRhs.data(),actualRhs.innerStride(),
+ actualDestPtr,1,compatibleAlpha);
+
+ if (!evalToDest)
+ {
+ if(!alphaIsCompatible)
+ dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
+ else
+ dest = MappedDest(actualDestPtr, dest.size());
+ if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size());
+ }
}
};
+template<> struct trmv_selector<RowMajor>
+{
+ template<int Mode, typename Lhs, typename Rhs, typename Dest>
+ static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
+ {
+ typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
+ typedef typename ProductType::LhsScalar LhsScalar;
+ typedef typename ProductType::RhsScalar RhsScalar;
+ typedef typename ProductType::Scalar ResScalar;
+ typedef typename ProductType::Index Index;
+ typedef typename ProductType::ActualLhsType ActualLhsType;
+ typedef typename ProductType::ActualRhsType ActualRhsType;
+ typedef typename ProductType::_ActualRhsType _ActualRhsType;
+ typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
+ typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
+
+ typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
+ typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
+
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
+ * RhsBlasTraits::extractScalarFactor(prod.rhs());
+
+ enum {
+ DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
+ };
+
+ gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
+
+ RhsScalar* actualRhsPtr;
+ bool freeRhsPtr = false;
+ if (DirectlyUseRhs)
+ {
+ actualRhsPtr = const_cast<RhsScalar*>(actualRhs.data());
+ }
+ else
+ {
+ #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ int size = actualRhs.size();
+ EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+ #endif
+ if((actualRhsPtr = static_rhs.data())==0)
+ {
+ freeRhsPtr = true;
+ actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size());
+ }
+ Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
+ }
+
+ internal::product_triangular_matrix_vector
+ <Index,Mode,
+ LhsScalar, LhsBlasTraits::NeedToConjugate,
+ RhsScalar, RhsBlasTraits::NeedToConjugate,
+ RowMajor>
+ ::run(actualLhs.rows(),actualLhs.cols(),
+ actualLhs.data(),actualLhs.outerStride(),
+ actualRhsPtr,1,
+ dest.data(),dest.innerStride(),
+ actualAlpha);
+
+ if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size());
+ }
+};
+
+} // end namespace internal
+
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
diff --git a/test/nomalloc.cpp b/test/nomalloc.cpp
index 94c1b0533..7ef71bfcd 100644
--- a/test/nomalloc.cpp
+++ b/test/nomalloc.cpp
@@ -71,7 +71,7 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
VERIFY_IS_APPROX((m1+m2)(r,c), (m1(r,c))+(m2(r,c)));
VERIFY_IS_APPROX(m1.cwiseProduct(m1.block(0,0,rows,cols)), (m1.array()*m1.array()).matrix());
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
-
+
m2.col(0).noalias() = m1 * m1.col(0);
m2.col(0).noalias() -= m1.adjoint() * m1.col(0);
m2.col(0).noalias() -= m1 * m1.row(0).adjoint();
@@ -81,6 +81,36 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
+ VERIFY_IS_APPROX(m2,m2);
+
+ m2.col(0).noalias() = m1.template triangularView<Upper>() * m1.col(0);
+ m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.col(0);
+ m2.col(0).noalias() -= m1.template triangularView<Upper>() * m1.row(0).adjoint();
+ m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.row(0).adjoint();
+
+ m2.row(0).noalias() = m1.row(0) * m1.template triangularView<Upper>();
+ m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template triangularView<Upper>();
+ m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template triangularView<Upper>();
+ m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template triangularView<Upper>();
+ VERIFY_IS_APPROX(m2,m2);
+
+ m2.col(0).noalias() = m1.template selfadjointView<Upper>() * m1.col(0);
+ m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.col(0);
+ m2.col(0).noalias() -= m1.template selfadjointView<Upper>() * m1.row(0).adjoint();
+ m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.row(0).adjoint();
+
+ m2.row(0).noalias() = m1.row(0) * m1.template selfadjointView<Upper>();
+ m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template selfadjointView<Upper>();
+ m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template selfadjointView<Upper>();
+ m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template selfadjointView<Upper>();
+ VERIFY_IS_APPROX(m2,m2);
+
+ // The following fancy matrix-matrix products are not safe yet regarding static allocation
+// m1 += m1.template triangularView<Upper>() * m2.col(;
+// m1.template selfadjointView<Lower>().rankUpdate(m2);
+// m1 += m1.template triangularView<Upper>() * m2;
+// m1 += m1.template selfadjointView<Lower>() * m2;
+// VERIFY_IS_APPROX(m1,m1);
}
template<typename Scalar>