aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-06-27 15:54:44 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-06-27 15:54:44 +0200
commit73e686c6a44adfd54fe75a0d7581fd14bfa58f54 (patch)
tree9e4da476ae70595190bdacb8a0df4e079d6620af /Eigen
parentae039dde135a6af852d7028abd772316613a5249 (diff)
Implement evaluators for sparse times diagonal products.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/SparseCore2
-rw-r--r--Eigen/src/SparseCore/SparseDiagonalProduct.h126
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h14
3 files changed, 139 insertions, 3 deletions
diff --git a/Eigen/SparseCore b/Eigen/SparseCore
index 340ff7f52..b338950ca 100644
--- a/Eigen/SparseCore
+++ b/Eigen/SparseCore
@@ -48,6 +48,7 @@ struct Sparse {};
#include "src/SparseCore/SparseDot.h"
#include "src/SparseCore/SparseRedux.h"
#include "src/SparseCore/SparseView.h"
+#include "src/SparseCore/SparseDiagonalProduct.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
@@ -55,7 +56,6 @@ struct Sparse {};
#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseProduct.h"
#include "src/SparseCore/SparseDenseProduct.h"
-#include "src/SparseCore/SparseDiagonalProduct.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
#include "src/SparseCore/TriangularSolver.h"
diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h
index 1bb590e64..cf0f77342 100644
--- a/Eigen/src/SparseCore/SparseDiagonalProduct.h
+++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h
@@ -24,8 +24,10 @@ namespace Eigen {
// for that particular case
// The two other cases are symmetric.
+#ifndef EIGEN_TEST_EVALUATORS
+
namespace internal {
-
+
template<typename Lhs, typename Rhs>
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
{
@@ -100,9 +102,14 @@ class SparseDiagonalProduct
LhsNested m_lhs;
RhsNested m_rhs;
};
+#endif
namespace internal {
+#ifndef EIGEN_TEST_EVALUATORS
+
+
+
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
@@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector
inline Index row() const { return m_outer; }
};
+#else // EIGEN_TEST_EVALUATORS
+enum {
+ SDP_AsScalarProduct,
+ SDP_AsCwiseProduct
+};
+
+template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
+struct sparse_diagonal_product_evaluator;
+
+template<typename Lhs, typename Rhs, int Options, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
+{
+ typedef Product<Lhs, Rhs, Options> XprType;
+ typedef evaluator<XprType> type;
+ typedef evaluator<XprType> nestedType;
+ enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
+
+ typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
+ product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
+};
+
+template<typename Lhs, typename Rhs, int Options, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
+{
+ typedef Product<Lhs, Rhs, Options> XprType;
+ typedef evaluator<XprType> type;
+ typedef evaluator<XprType> nestedType;
+ enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
+
+ typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
+ product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
+};
+
+template<typename SparseXprType, typename DiagonalCoeffType>
+struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
+{
+protected:
+ typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
+ typedef typename SparseXprType::Scalar Scalar;
+ typedef typename SparseXprType::Index Index;
+
+public:
+ class InnerIterator : public SparseXprInnerIterator
+ {
+ public:
+ InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
+ : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
+ m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
+ {}
+
+ EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
+ protected:
+ typename DiagonalCoeffType::Scalar m_coeff;
+ };
+
+ sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
+ : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
+ {}
+
+protected:
+ typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
+ typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
+};
+
+
+template<typename SparseXprType, typename DiagCoeffType>
+struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
+{
+ typedef typename SparseXprType::Scalar Scalar;
+ typedef typename SparseXprType::Index Index;
+
+ typedef CwiseBinaryOp<scalar_product_op<Scalar>,
+ const typename SparseXprType::ConstInnerVectorReturnType,
+ const DiagCoeffType> CwiseProductType;
+
+ typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
+ typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
+
+ class InnerIterator : public CwiseProductIterator
+ {
+ public:
+ InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
+ : CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0),
+ m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
+ m_outer(outer)
+ {
+ ::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0);
+ }
+
+ inline Index outer() const { return m_outer; }
+ inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; }
+ inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); }
+
+ protected:
+ Index m_outer;
+ CwiseProductEval m_cwiseEval;
+ };
+
+ sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
+ : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
+ {}
+
+protected:
+ typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
+ typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
+ : SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
+};
+
+#endif // EIGEN_TEST_EVALUATORS
+
+
} // end namespace internal
-// SparseMatrixBase functions
+#ifndef EIGEN_TEST_EVALUATORS
+// SparseMatrixBase functions
template<typename Derived>
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
{
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
}
+#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index 4c46008ae..c71244d3e 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -269,6 +269,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
operator*(const SparseMatrixBase<OtherDerived> &other) const;
+#ifndef EIGEN_TEST_EVALUATORS
// sparse * diagonal
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@@ -279,6 +280,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const SparseDiagonalProduct<OtherDerived,Derived>
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
+#else // EIGEN_TEST_EVALUATORS
+ // sparse * diagonal
+ template<typename OtherDerived>
+ const Product<Derived,OtherDerived>
+ operator*(const DiagonalBase<OtherDerived> &other) const
+ { return Product<Derived,OtherDerived>(derived(), other.derived()); }
+
+ // diagonal * sparse
+ template<typename OtherDerived> friend
+ const Product<OtherDerived,Derived>
+ operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
+ { return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
+#endif // EIGEN_TEST_EVALUATORS
/** dense * sparse (return a dense object unless it is an outer product) */
template<typename OtherDerived> friend