aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-17 16:10:55 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-17 16:10:55 +0100
commitbffa15142c4271313a70801e6bb7d01365a00bc9 (patch)
tree205f1a6031a294d939501cfbcc058ad3eb6c8171
parent94acccc126d430bf34587527d84ff9b389219c2f (diff)
Add evaluator support for diagonal products
-rw-r--r--Eigen/src/Core/DiagonalMatrix.h36
-rw-r--r--Eigen/src/Core/ProductEvaluators.h146
-rw-r--r--test/evaluators.cpp17
3 files changed, 192 insertions, 7 deletions
diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h
index f7ac22f8b..8df636928 100644
--- a/Eigen/src/Core/DiagonalMatrix.h
+++ b/Eigen/src/Core/DiagonalMatrix.h
@@ -66,6 +66,7 @@ class DiagonalBase : public EigenBase<Derived>
EIGEN_DEVICE_FUNC
inline Index cols() const { return diagonal().size(); }
+#ifndef EIGEN_TEST_EVALUATORS
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
*/
template<typename MatrixDerived>
@@ -75,6 +76,15 @@ class DiagonalBase : public EigenBase<Derived>
{
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
}
+#else
+ template<typename MatrixDerived>
+ EIGEN_DEVICE_FUNC
+ const Product<Derived,MatrixDerived,LazyProduct>
+ operator*(const MatrixBase<MatrixDerived> &matrix) const
+ {
+ return Product<Derived, MatrixDerived, LazyProduct>(derived(),matrix.derived());
+ }
+#endif // EIGEN_TEST_EVALUATORS
EIGEN_DEVICE_FUNC
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
@@ -270,7 +280,8 @@ struct traits<DiagonalWrapper<_DiagonalVectorType> >
ColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
MaxRowsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
MaxColsAtCompileTime = DiagonalVectorType::SizeAtCompileTime,
- Flags = traits<DiagonalVectorType>::Flags & LvalueBit
+ Flags = traits<DiagonalVectorType>::Flags & LvalueBit,
+ CoeffReadCost = traits<_DiagonalVectorType>::CoeffReadCost
};
};
}
@@ -341,6 +352,29 @@ bool MatrixBase<Derived>::isDiagonal(const RealScalar& prec) const
return true;
}
+#ifdef EIGEN_ENABLE_EVALUATORS
+namespace internal {
+
+// TODO currently a diagonal expression has the form DiagonalMatrix<> or DiagonalWrapper
+// in the future diagonal-ness should be defined by the expression traits
+template<typename _Scalar, int SizeAtCompileTime, int MaxSizeAtCompileTime>
+struct evaluator_traits<DiagonalMatrix<_Scalar,SizeAtCompileTime,MaxSizeAtCompileTime> >
+{
+ typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
+ typedef DiagonalShape Shape;
+ static const int AssumeAliasing = 0;
+};
+template<typename Derived>
+struct evaluator_traits<DiagonalWrapper<Derived> >
+{
+ typedef typename storage_kind_to_evaluator_kind<typename Derived::StorageKind>::Kind Kind;
+ typedef DiagonalShape Shape;
+ static const int AssumeAliasing = 0;
+};
+
+} // namespace internal
+#endif // EIGEN_ENABLE_EVALUATORS
+
} // end namespace Eigen
#endif // EIGEN_DIAGONALMATRIX_H
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index 2279ec33b..d991ff8b5 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -309,9 +309,9 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
};
// This specialization enforces the use of a coefficient-based evaluation strategy
-// template<typename Lhs, typename Rhs>
-// struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
-// : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
+template<typename Lhs, typename Rhs>
+struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
+ : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
// Case 2: Evaluate coeff by coeff
//
@@ -764,6 +764,146 @@ protected:
PlainObject m_result;
};
+/***************************************************************************
+* Diagonal products
+***************************************************************************/
+
+template<typename MatrixType, typename DiagonalType, typename Derived>
+struct diagonal_product_evaluator_base
+ : evaluator_base<Derived>
+{
+ typedef typename MatrixType::Index Index;
+ typedef typename MatrixType::Scalar Scalar;
+ typedef typename MatrixType::PacketScalar PacketScalar;
+public:
+ diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
+ : m_diagImpl(diag), m_matImpl(mat)
+ {
+ }
+
+ EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
+ {
+ return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
+ }
+
+protected:
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
+ {
+ return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
+ internal::pset1<PacketScalar>(m_diagImpl.coeff(id)));
+ }
+
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
+ {
+ enum {
+ InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
+ DiagonalPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
+ };
+ return internal::pmul(m_matImpl.template packet<LoadMode>(row, col),
+ m_diagImpl.template packet<DiagonalPacketLoadMode>(id));
+ }
+
+ typename evaluator<DiagonalType>::nestedType m_diagImpl;
+ typename evaluator<MatrixType>::nestedType m_matImpl;
+};
+
+// diagonal * dense
+template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
+{
+ typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
+ using Base::m_diagImpl;
+ using Base::m_matImpl;
+ using Base::coeff;
+ using Base::packet_impl;
+ typedef typename Base::Scalar Scalar;
+ typedef typename Base::Index Index;
+ typedef typename Base::PacketScalar PacketScalar;
+
+ typedef Product<Lhs, Rhs, ProductKind> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+
+ product_evaluator(const XprType& xpr)
+ : Base(xpr.rhs(), xpr.lhs().diagonal())
+ {
+ }
+
+ EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
+ {
+ return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
+ }
+
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
+ {
+ enum {
+ StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
+ };
+ return this->template packet_impl<LoadMode>(row,col, row,
+ typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
+ }
+
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
+ {
+ enum {
+ StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
+ };
+ return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
+ }
+
+};
+
+// dense * diagonal
+template<typename Lhs, typename Rhs, int ProductKind, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> >
+{
+ typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct> > Base;
+ using Base::m_diagImpl;
+ using Base::m_matImpl;
+ using Base::coeff;
+ using Base::packet_impl;
+ typedef typename Base::Scalar Scalar;
+ typedef typename Base::Index Index;
+ typedef typename Base::PacketScalar PacketScalar;
+
+ typedef Product<Lhs, Rhs, ProductKind> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+
+ product_evaluator(const XprType& xpr)
+ : Base(xpr.lhs(), xpr.rhs().diagonal())
+ {
+ }
+
+ EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
+ {
+ return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
+ }
+
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
+ {
+ enum {
+ StorageOrder = Rhs::Flags & RowMajorBit ? RowMajor : ColMajor
+ };
+ return this->template packet_impl<LoadMode>(row,col, col,
+ typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
+ }
+
+ template<int LoadMode>
+ EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
+ {
+ enum {
+ StorageOrder = int(Rhs::Flags) & RowMajorBit ? RowMajor : ColMajor
+ };
+ return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
+ }
+
+};
} // end namespace internal
diff --git a/test/evaluators.cpp b/test/evaluators.cpp
index 69a45661f..305047a6a 100644
--- a/test/evaluators.cpp
+++ b/test/evaluators.cpp
@@ -151,19 +151,19 @@ void test_evaluators()
c = a*a;
copy_using_evaluator(a, prod(a,a));
VERIFY_IS_APPROX(a,c);
-
+
// check compound assignment of products
d = c;
add_assign_using_evaluator(c.noalias(), prod(a,b));
d.noalias() += a*b;
VERIFY_IS_APPROX(c, d);
-
+
d = c;
subtract_assign_using_evaluator(c.noalias(), prod(a,b));
d.noalias() -= a*b;
VERIFY_IS_APPROX(c, d);
}
-
+
{
// test product with all possible sizes
int s = internal::random<int>(1,100);
@@ -458,4 +458,15 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(B, prod(A.selfadjointView<Upper>(),A), MatrixXd(A.selfadjointView<Upper>()*A));
}
+
+ {
+ // test diagonal shapes
+ VectorXd d = VectorXd::Random(6);
+ MatrixXd A = MatrixXd::Random(6,6), B(6,6);
+ A.setRandom();B.setRandom();
+
+ VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(d.asDiagonal(),A), MatrixXd(d.asDiagonal()*A));
+ VERIFY_IS_APPROX_EVALUATOR2(B, lazyprod(A,d.asDiagonal()), MatrixXd(A*d.asDiagonal()));
+
+ }
}