aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/SparseCore/SparseDenseProduct.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-07-19 14:55:56 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-07-19 14:55:56 +0200
commit3eba5e1101d8652483e1cf232a06dccf49a8a530 (patch)
tree7e735ece3f686637cd2d95c075c89ec2712ab2ea /Eigen/src/SparseCore/SparseDenseProduct.h
parent36e6c9064fc68d5c47473f6d251da10e96ad42b3 (diff)
Implement evaluator for sparse outer products
Diffstat (limited to 'Eigen/src/SparseCore/SparseDenseProduct.h')
-rw-r--r--Eigen/src/SparseCore/SparseDenseProduct.h155
1 files changed, 154 insertions, 1 deletions
diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h
index 116edd62e..883e24acb 100644
--- a/Eigen/src/SparseCore/SparseDenseProduct.h
+++ b/Eigen/src/SparseCore/SparseDenseProduct.h
@@ -13,7 +13,10 @@
namespace Eigen {
namespace internal {
-
+
+template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; };
+template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; };
+
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
typename AlphaType,
int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
@@ -445,6 +448,156 @@ protected:
PlainObject m_result;
};
+
+// template<typename Lhs, typename Rhs, bool Transpose, typename LhsIterator>
+// class sparse_dense_outer_product_iterator : public LhsIterator
+// {
+// typedef typename SparseDenseOuterProduct::Index Index;
+// public:
+// template<typename XprEval>
+// EIGEN_STRONG_INLINE InnerIterator(const XprEval& prod, Index outer)
+// : LhsIterator(prod.lhs(), 0),
+// m_outer(outer), m_empty(false), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
+// {}
+//
+// inline Index outer() const { return m_outer; }
+// inline Index row() const { return Transpose ? m_outer : Base::index(); }
+// inline Index col() const { return Transpose ? Base::index() : m_outer; }
+//
+// inline Scalar value() const { return Base::value() * m_factor; }
+// inline operator bool() const { return Base::operator bool() && !m_empty; }
+//
+// protected:
+// Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense()) const
+// {
+// return rhs.coeff(outer);
+// }
+//
+// Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
+// {
+// typename Traits::_RhsNested::InnerIterator it(rhs, outer);
+// if (it && it.index()==0 && it.value()!=Scalar(0))
+// return it.value();
+// m_empty = true;
+// return Scalar(0);
+// }
+//
+// Index m_outer;
+// bool m_empty;
+// Scalar m_factor;
+// };
+
+template<typename LhsT, typename RhsT, bool Transpose>
+struct sparse_dense_outer_product_evaluator
+{
+protected:
+ typedef typename conditional<Transpose,RhsT,LhsT>::type Lhs1;
+ typedef typename conditional<Transpose,LhsT,RhsT>::type Rhs;
+ typedef Product<LhsT,RhsT> ProdXprType;
+
+ // if the actual left-hand side is a dense vector,
+ // then build a sparse-view so that we can seamlessly iterator over it.
+ typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
+ Lhs1, SparseView<Lhs1> >::type Lhs;
+ typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
+ Lhs1 const&, SparseView<Lhs1> >::type LhsArg;
+
+ typedef typename evaluator<Lhs>::type LhsEval;
+ typedef typename evaluator<Rhs>::type RhsEval;
+ typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
+ typedef typename ProdXprType::Scalar Scalar;
+ typedef typename ProdXprType::Index Index;
+
+public:
+ enum {
+ Flags = Transpose ? RowMajorBit : 0,
+ CoeffReadCost = Dynamic
+ };
+
+ class InnerIterator : public LhsIterator
+ {
+ public:
+ InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer)
+ : LhsIterator(xprEval.m_lhsXprImpl, 0),
+ m_outer(outer),
+ m_empty(false),
+ m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<Rhs>::StorageKind() ))
+ {}
+
+ EIGEN_STRONG_INLINE Index outer() const { return m_outer; }
+ EIGEN_STRONG_INLINE Index row() const { return Transpose ? m_outer : LhsIterator::index(); }
+ EIGEN_STRONG_INLINE Index col() const { return Transpose ? LhsIterator::index() : m_outer; }
+
+ EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; }
+ EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); }
+
+
+ protected:
+ Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const
+ {
+ return rhs.coeff(outer);
+ }
+
+ Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse())
+ {
+ typename RhsEval::InnerIterator it(rhs, outer);
+ if (it && it.index()==0 && it.value()!=Scalar(0))
+ return it.value();
+ m_empty = true;
+ return Scalar(0);
+ }
+
+ Index m_outer;
+ bool m_empty;
+ Scalar m_factor;
+ };
+
+ sparse_dense_outer_product_evaluator(const Lhs &lhs, const Rhs &rhs)
+ : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
+ {}
+
+ // transpose case
+ sparse_dense_outer_product_evaluator(const Rhs &rhs, const Lhs1 &lhs)
+ : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
+ {}
+
+protected:
+ const LhsArg m_lhs;
+ typename evaluator<Lhs>::nestedType m_lhsXprImpl;
+ typename evaluator<Rhs>::nestedType m_rhsXprImpl;
+};
+
+// sparse * dense outer product
+template<typename Lhs, typename Rhs>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
+{
+ typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
+
+ typedef Product<Lhs, Rhs> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+
+ product_evaluator(const XprType& xpr)
+ : Base(xpr.lhs(), xpr.rhs())
+ {}
+
+};
+
+template<typename Lhs, typename Rhs>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
+{
+ typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
+
+ typedef Product<Lhs, Rhs> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+
+ product_evaluator(const XprType& xpr)
+ : Base(xpr.lhs(), xpr.rhs())
+ {}
+
+};
+
} // end namespace internal
#endif // EIGEN_TEST_EVALUATORS