diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-04-11 17:20:17 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-04-11 17:20:17 -0700 |
commit | d6e596174d09446236b3f398d8ec39148c638ed9 (patch) | |
tree | ccb4116b05dc11d7931bac0129fd1394abe1e0b0 /Eigen/src/SparseCore/SparseCwiseBinaryOp.h | |
parent | 3ca1ae2bb761d7738bcdad885639f422a6b7c914 (diff) | |
parent | 833efb39bfe4957934982112fe435ab30a0c3b4f (diff) |
Pull latest updates from upstream
Diffstat (limited to 'Eigen/src/SparseCore/SparseCwiseBinaryOp.h')
-rw-r--r-- | Eigen/src/SparseCore/SparseCwiseBinaryOp.h | 221 |
1 files changed, 210 insertions, 11 deletions
diff --git a/Eigen/src/SparseCore/SparseCwiseBinaryOp.h b/Eigen/src/SparseCore/SparseCwiseBinaryOp.h index d9420ac63..c57d9ac59 100644 --- a/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +++ b/Eigen/src/SparseCore/SparseCwiseBinaryOp.h @@ -49,17 +49,10 @@ class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> namespace internal { -template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived, - typename _LhsStorageMode = typename traits<Lhs>::StorageKind, - typename _RhsStorageMode = typename traits<Rhs>::StorageKind> -class sparse_cwise_binary_op_inner_iterator_selector; - -} // end namespace internal - -namespace internal { - // Generic "sparse OP sparse" +template<typename XprType> struct binary_sparse_evaluator; + template<typename BinaryOp, typename Lhs, typename Rhs> struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IteratorBased> : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > @@ -153,6 +146,182 @@ protected: evaluator<Rhs> m_rhsImpl; }; +// dense op sparse +template<typename BinaryOp, typename Lhs, typename Rhs> +struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IndexBased, IteratorBased> + : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > +{ +protected: + typedef typename evaluator<Rhs>::InnerIterator RhsIterator; + typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; + typedef typename traits<XprType>::Scalar Scalar; + typedef typename XprType::StorageIndex StorageIndex; +public: + + class ReverseInnerIterator; + class InnerIterator + { + enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; + public: + + EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer) + : m_lhsEval(aEval.m_lhsImpl), m_rhsIter(aEval.m_rhsImpl,outer), m_functor(aEval.m_functor), m_id(-1), m_innerSize(aEval.m_expr.rhs().innerSize()) + { + this->operator++(); + } + + EIGEN_STRONG_INLINE InnerIterator& operator++() + { + ++m_id; + if(m_id<m_innerSize) + { + Scalar lhsVal = m_lhsEval.coeff(IsRowMajor?m_rhsIter.outer():m_id, + IsRowMajor?m_id:m_rhsIter.outer()); + if(m_rhsIter && m_rhsIter.index()==m_id) + { + m_value = m_functor(lhsVal, m_rhsIter.value()); + ++m_rhsIter; + } + else + m_value = m_functor(lhsVal, Scalar(0)); + } + + return *this; + } + + EIGEN_STRONG_INLINE Scalar value() const { return m_value; } + + EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; } + EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_rhsIter.outer() : m_id; } + EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_rhsIter.outer(); } + + EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; } + + protected: + const evaluator<Lhs> &m_lhsEval; + RhsIterator m_rhsIter; + const BinaryOp& m_functor; + Scalar m_value; + StorageIndex m_id; + StorageIndex m_innerSize; + }; + + + enum { + CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost, + // Expose storage order of the sparse expression + Flags = (XprType::Flags & ~RowMajorBit) | (int(Rhs::Flags)&RowMajorBit) + }; + + explicit binary_evaluator(const XprType& xpr) + : m_functor(xpr.functor()), + m_lhsImpl(xpr.lhs()), + m_rhsImpl(xpr.rhs()), + m_expr(xpr) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + inline Index nonZerosEstimate() const { + return m_expr.size(); + } + +protected: + const BinaryOp m_functor; + evaluator<Lhs> m_lhsImpl; + evaluator<Rhs> m_rhsImpl; + const XprType &m_expr; +}; + +// sparse op dense +template<typename BinaryOp, typename Lhs, typename Rhs> +struct binary_evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs>, IteratorBased, IndexBased> + : evaluator_base<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > +{ +protected: + typedef typename evaluator<Lhs>::InnerIterator LhsIterator; + typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> XprType; + typedef typename traits<XprType>::Scalar Scalar; + typedef typename XprType::StorageIndex StorageIndex; +public: + + class ReverseInnerIterator; + class InnerIterator + { + enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; + public: + + EIGEN_STRONG_INLINE InnerIterator(const binary_evaluator& aEval, Index outer) + : m_lhsIter(aEval.m_lhsImpl,outer), m_rhsEval(aEval.m_rhsImpl), m_functor(aEval.m_functor), m_id(-1), m_innerSize(aEval.m_expr.lhs().innerSize()) + { + this->operator++(); + } + + EIGEN_STRONG_INLINE InnerIterator& operator++() + { + ++m_id; + if(m_id<m_innerSize) + { + Scalar rhsVal = m_rhsEval.coeff(IsRowMajor?m_lhsIter.outer():m_id, + IsRowMajor?m_id:m_lhsIter.outer()); + if(m_lhsIter && m_lhsIter.index()==m_id) + { + m_value = m_functor(m_lhsIter.value(), rhsVal); + ++m_lhsIter; + } + else + m_value = m_functor(Scalar(0),rhsVal); + } + + return *this; + } + + EIGEN_STRONG_INLINE Scalar value() const { return m_value; } + + EIGEN_STRONG_INLINE StorageIndex index() const { return m_id; } + EIGEN_STRONG_INLINE Index row() const { return IsRowMajor ? m_lhsIter.outer() : m_id; } + EIGEN_STRONG_INLINE Index col() const { return IsRowMajor ? m_id : m_lhsIter.outer(); } + + EIGEN_STRONG_INLINE operator bool() const { return m_id<m_innerSize; } + + protected: + LhsIterator m_lhsIter; + const evaluator<Rhs> &m_rhsEval; + const BinaryOp& m_functor; + Scalar m_value; + StorageIndex m_id; + StorageIndex m_innerSize; + }; + + + enum { + CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost, + // Expose storage order of the sparse expression + Flags = (XprType::Flags & ~RowMajorBit) | (int(Lhs::Flags)&RowMajorBit) + }; + + explicit binary_evaluator(const XprType& xpr) + : m_functor(xpr.functor()), + m_lhsImpl(xpr.lhs()), + m_rhsImpl(xpr.rhs()), + m_expr(xpr) + { + EIGEN_INTERNAL_CHECK_COST_VALUE(functor_traits<BinaryOp>::Cost); + EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); + } + + inline Index nonZerosEstimate() const { + return m_expr.size(); + } + +protected: + const BinaryOp m_functor; + evaluator<Lhs> m_lhsImpl; + evaluator<Rhs> m_rhsImpl; + const XprType &m_expr; +}; + // "sparse .* sparse" template<typename T, typename Lhs, typename Rhs> struct binary_evaluator<CwiseBinaryOp<scalar_product_op<T>, Lhs, Rhs>, IteratorBased, IteratorBased> @@ -287,7 +456,8 @@ public: enum { CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost, - Flags = XprType::Flags + // Expose storage order of the sparse expression + Flags = (XprType::Flags & ~RowMajorBit) | (int(Rhs::Flags)&RowMajorBit) }; explicit binary_evaluator(const XprType& xpr) @@ -360,7 +530,8 @@ public: enum { CoeffReadCost = evaluator<Lhs>::CoeffReadCost + evaluator<Rhs>::CoeffReadCost + functor_traits<BinaryOp>::Cost, - Flags = XprType::Flags + // Expose storage order of the sparse expression + Flags = (XprType::Flags & ~RowMajorBit) | (int(Lhs::Flags)&RowMajorBit) }; explicit binary_evaluator(const XprType& xpr) @@ -428,6 +599,34 @@ SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) c return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived()); } +template<typename DenseDerived, typename SparseDerived> +EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived> +operator+(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b) +{ + return CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived()); +} + +template<typename SparseDerived, typename DenseDerived> +EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived> +operator+(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b) +{ + return CwiseBinaryOp<internal::scalar_sum_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived()); +} + +template<typename DenseDerived, typename SparseDerived> +EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived> +operator-(const MatrixBase<DenseDerived> &a, const SparseMatrixBase<SparseDerived> &b) +{ + return CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const DenseDerived, const SparseDerived>(a.derived(), b.derived()); +} + +template<typename SparseDerived, typename DenseDerived> +EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived> +operator-(const SparseMatrixBase<SparseDerived> &a, const MatrixBase<DenseDerived> &b) +{ + return CwiseBinaryOp<internal::scalar_difference_op<typename DenseDerived::Scalar>, const SparseDerived, const DenseDerived>(a.derived(), b.derived()); +} + } // end namespace Eigen #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H |