aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/SparseCore/SparseDiagonalProduct.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/SparseCore/SparseDiagonalProduct.h')
-rw-r--r--Eigen/src/SparseCore/SparseDiagonalProduct.h133
1 files changed, 130 insertions, 3 deletions
diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h
index c056b4914..9f465a828 100644
--- a/Eigen/src/SparseCore/SparseDiagonalProduct.h
+++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2009-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -24,8 +24,10 @@ namespace Eigen {
// for that particular case
// The two other cases are symmetric.
+#ifndef EIGEN_TEST_EVALUATORS
+
namespace internal {
-
+
template<typename Lhs, typename Rhs>
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
{
@@ -102,9 +104,14 @@ class SparseDiagonalProduct
LhsNested m_lhs;
RhsNested m_rhs;
};
+#endif
namespace internal {
+#ifndef EIGEN_TEST_EVALUATORS
+
+
+
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
class sparse_diagonal_product_inner_iterator_selector
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
@@ -181,10 +188,129 @@ class sparse_diagonal_product_inner_iterator_selector
inline Index row() const { return m_outer; }
};
+#else // EIGEN_TEST_EVALUATORS
+enum {
+ SDP_AsScalarProduct,
+ SDP_AsCwiseProduct
+};
+
+template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
+struct sparse_diagonal_product_evaluator;
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
+ : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef evaluator<XprType> type;
+ typedef evaluator<XprType> nestedType;
+ enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
+
+ typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
+ product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
+};
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
+ : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef evaluator<XprType> type;
+ typedef evaluator<XprType> nestedType;
+ enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
+
+ typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
+ product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
+};
+
+template<typename SparseXprType, typename DiagonalCoeffType>
+struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
+{
+protected:
+ typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
+ typedef typename SparseXprType::Scalar Scalar;
+ typedef typename SparseXprType::Index Index;
+
+public:
+ class InnerIterator : public SparseXprInnerIterator
+ {
+ public:
+ InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
+ : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
+ m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
+ {}
+
+ EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
+ protected:
+ typename DiagonalCoeffType::Scalar m_coeff;
+ };
+
+ sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
+ : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
+ {}
+
+protected:
+ typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
+ typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
+};
+
+
+template<typename SparseXprType, typename DiagCoeffType>
+struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
+{
+ typedef typename SparseXprType::Scalar Scalar;
+ typedef typename SparseXprType::Index Index;
+
+ typedef CwiseBinaryOp<scalar_product_op<Scalar>,
+ const typename SparseXprType::ConstInnerVectorReturnType,
+ const DiagCoeffType> CwiseProductType;
+
+ typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
+ typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
+
+ class InnerIterator
+ {
+ public:
+ InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
+ : m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
+ m_cwiseIter(m_cwiseEval, 0),
+ m_outer(outer)
+ {}
+
+ inline Scalar value() const { return m_cwiseIter.value(); }
+ inline Index index() const { return m_cwiseIter.index(); }
+ inline Index outer() const { return m_outer; }
+ inline Index col() const { return SparseXprType::IsRowMajor ? m_cwiseIter.index() : m_outer; }
+ inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : m_cwiseIter.index(); }
+
+ EIGEN_STRONG_INLINE InnerIterator& operator++()
+ { ++m_cwiseIter; return *this; }
+ inline operator bool() const { return m_cwiseIter; }
+
+ protected:
+ CwiseProductEval m_cwiseEval;
+ CwiseProductIterator m_cwiseIter;
+ Index m_outer;
+ };
+
+ sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
+ : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
+ {}
+
+protected:
+ typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
+ typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
+ : SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
+};
+
+#endif // EIGEN_TEST_EVALUATORS
+
+
} // end namespace internal
-// SparseMatrixBase functions
+#ifndef EIGEN_TEST_EVALUATORS
+// SparseMatrixBase functions
template<typename Derived>
template<typename OtherDerived>
const SparseDiagonalProduct<Derived,OtherDerived>
@@ -192,6 +318,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
{
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
}
+#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen