aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-07-01 17:53:18 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-07-01 17:53:18 +0200
commit7390af91b63e4f250bddd5971eab44bae3a89f32 (patch)
treeb5e16acde6c549f1ff45833dc559ce0a2dabdaaa /Eigen
parent1e6f53e070ffb3d386bea3cda5e37569c0f11b37 (diff)
Implement evaluators for sparse*dense products
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/SparseCore2
-rw-r--r--Eigen/src/SparseCore/SparseDenseProduct.h334
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h34
3 files changed, 252 insertions, 118 deletions
diff --git a/Eigen/SparseCore b/Eigen/SparseCore
index f74df3038..7cbfb47f2 100644
--- a/Eigen/SparseCore
+++ b/Eigen/SparseCore
@@ -52,10 +52,10 @@ struct Sparse {};
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
#include "src/SparseCore/SparseSparseProductWithPruning.h"
#include "src/SparseCore/SparseProduct.h"
+#include "src/SparseCore/SparseDenseProduct.h"
#ifndef EIGEN_TEST_EVALUATORS
#include "src/SparseCore/SparsePermutation.h"
#include "src/SparseCore/SparseFuzzy.h"
-#include "src/SparseCore/SparseDenseProduct.h"
#include "src/SparseCore/SparseTriangularView.h"
#include "src/SparseCore/SparseSelfAdjointView.h"
#include "src/SparseCore/TriangularSolver.h"
diff --git a/Eigen/src/SparseCore/SparseDenseProduct.h b/Eigen/src/SparseCore/SparseDenseProduct.h
index 610833f3b..2a23365c6 100644
--- a/Eigen/src/SparseCore/SparseDenseProduct.h
+++ b/Eigen/src/SparseCore/SparseDenseProduct.h
@@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
-// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
+// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -12,6 +12,152 @@
namespace Eigen {
+namespace internal {
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
+ typename AlphaType,
+ int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
+ bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
+struct sparse_time_dense_product_impl;
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
+{
+ typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+ typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+ typedef typename internal::remove_all<DenseResType>::type Res;
+ typedef typename Lhs::Index Index;
+#ifndef EIGEN_TEST_EVALUATORS
+ typedef typename Lhs::InnerIterator LhsInnerIterator;
+#else
+ typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
+#endif
+ static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
+ {
+#ifndef EIGEN_TEST_EVALUATORS
+ const Lhs &lhsEval(lhs);
+#else
+ typename evaluator<Lhs>::type lhsEval(lhs);
+#endif
+ for(Index c=0; c<rhs.cols(); ++c)
+ {
+ Index n = lhs.outerSize();
+ for(Index j=0; j<n; ++j)
+ {
+ typename Res::Scalar tmp(0);
+ for(LhsInnerIterator it(lhsEval,j); it ;++it)
+ tmp += it.value() * rhs.coeff(it.index(),c);
+ res.coeffRef(j,c) = alpha * tmp;
+ }
+ }
+ }
+};
+
+template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
+struct scalar_product_traits<T1, Ref<T2/*, _Options, _StrideType*/> >
+{
+ enum {
+ Defined = 1
+ };
+ typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
+};
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
+{
+ typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+ typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+ typedef typename internal::remove_all<DenseResType>::type Res;
+ typedef typename Lhs::Index Index;
+#ifndef EIGEN_TEST_EVALUATORS
+ typedef typename Lhs::InnerIterator LhsInnerIterator;
+#else
+ typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
+#endif
+ static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
+ {
+#ifndef EIGEN_TEST_EVALUATORS
+ const Lhs &lhsEval(lhs);
+#else
+ typename evaluator<Lhs>::type lhsEval(lhs);
+#endif
+ for(Index c=0; c<rhs.cols(); ++c)
+ {
+ for(Index j=0; j<lhs.outerSize(); ++j)
+ {
+// typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
+ typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
+ for(LhsInnerIterator it(lhsEval,j); it ;++it)
+ res.coeffRef(it.index(),c) += it.value() * rhs_j;
+ }
+ }
+ }
+};
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
+{
+ typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+ typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+ typedef typename internal::remove_all<DenseResType>::type Res;
+ typedef typename Lhs::Index Index;
+#ifndef EIGEN_TEST_EVALUATORS
+ typedef typename Lhs::InnerIterator LhsInnerIterator;
+#else
+ typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
+#endif
+ static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
+ {
+#ifndef EIGEN_TEST_EVALUATORS
+ const Lhs &lhsEval(lhs);
+#else
+ typename evaluator<Lhs>::type lhsEval(lhs);
+#endif
+ for(Index j=0; j<lhs.outerSize(); ++j)
+ {
+ typename Res::RowXpr res_j(res.row(j));
+ for(LhsInnerIterator it(lhsEval,j); it ;++it)
+ res_j += (alpha*it.value()) * rhs.row(it.index());
+ }
+ }
+};
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
+struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
+{
+ typedef typename internal::remove_all<SparseLhsType>::type Lhs;
+ typedef typename internal::remove_all<DenseRhsType>::type Rhs;
+ typedef typename internal::remove_all<DenseResType>::type Res;
+ typedef typename Lhs::Index Index;
+#ifndef EIGEN_TEST_EVALUATORS
+ typedef typename Lhs::InnerIterator LhsInnerIterator;
+#else
+ typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
+#endif
+ static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
+ {
+#ifndef EIGEN_TEST_EVALUATORS
+ const Lhs &lhsEval(lhs);
+#else
+ typename evaluator<Lhs>::type lhsEval(lhs);
+#endif
+ for(Index j=0; j<lhs.outerSize(); ++j)
+ {
+ typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
+ for(LhsInnerIterator it(lhsEval,j); it ;++it)
+ res.row(it.index()) += (alpha*it.value()) * rhs_j;
+ }
+ }
+};
+
+template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
+inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
+{
+ sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
+}
+
+} // end namespace internal
+
+#ifndef EIGEN_TEST_EVALUATORS
template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
{
typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
@@ -138,111 +284,6 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
typedef MatrixXpr XprKind;
};
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
- typename AlphaType,
- int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
- bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
-struct sparse_time_dense_product_impl;
-
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
-struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
-{
- typedef typename internal::remove_all<SparseLhsType>::type Lhs;
- typedef typename internal::remove_all<DenseRhsType>::type Rhs;
- typedef typename internal::remove_all<DenseResType>::type Res;
- typedef typename Lhs::Index Index;
- typedef typename Lhs::InnerIterator LhsInnerIterator;
- static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
- {
- for(Index c=0; c<rhs.cols(); ++c)
- {
- Index n = lhs.outerSize();
- for(Index j=0; j<n; ++j)
- {
- typename Res::Scalar tmp(0);
- for(LhsInnerIterator it(lhs,j); it ;++it)
- tmp += it.value() * rhs.coeff(it.index(),c);
- res.coeffRef(j,c) = alpha * tmp;
- }
- }
- }
-};
-
-template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
-struct scalar_product_traits<T1, Ref<T2/*, _Options, _StrideType*/> >
-{
- enum {
- Defined = 1
- };
- typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
-};
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
-struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
-{
- typedef typename internal::remove_all<SparseLhsType>::type Lhs;
- typedef typename internal::remove_all<DenseRhsType>::type Rhs;
- typedef typename internal::remove_all<DenseResType>::type Res;
- typedef typename Lhs::InnerIterator LhsInnerIterator;
- typedef typename Lhs::Index Index;
- static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
- {
- for(Index c=0; c<rhs.cols(); ++c)
- {
- for(Index j=0; j<lhs.outerSize(); ++j)
- {
-// typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
- typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
- for(LhsInnerIterator it(lhs,j); it ;++it)
- res.coeffRef(it.index(),c) += it.value() * rhs_j;
- }
- }
- }
-};
-
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
-struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
-{
- typedef typename internal::remove_all<SparseLhsType>::type Lhs;
- typedef typename internal::remove_all<DenseRhsType>::type Rhs;
- typedef typename internal::remove_all<DenseResType>::type Res;
- typedef typename Lhs::InnerIterator LhsInnerIterator;
- typedef typename Lhs::Index Index;
- static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
- {
- for(Index j=0; j<lhs.outerSize(); ++j)
- {
- typename Res::RowXpr res_j(res.row(j));
- for(LhsInnerIterator it(lhs,j); it ;++it)
- res_j += (alpha*it.value()) * rhs.row(it.index());
- }
- }
-};
-
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
-struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
-{
- typedef typename internal::remove_all<SparseLhsType>::type Lhs;
- typedef typename internal::remove_all<DenseRhsType>::type Rhs;
- typedef typename internal::remove_all<DenseResType>::type Res;
- typedef typename Lhs::InnerIterator LhsInnerIterator;
- typedef typename Lhs::Index Index;
- static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
- {
- for(Index j=0; j<lhs.outerSize(); ++j)
- {
- typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
- for(LhsInnerIterator it(lhs,j); it ;++it)
- res.row(it.index()) += (alpha*it.value()) * rhs_j;
- }
- }
-};
-
-template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
-inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
-{
- sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
-}
-
} // end namespace internal
template<typename Lhs, typename Rhs>
@@ -305,6 +346,87 @@ SparseMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) cons
{
return typename SparseDenseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
}
+#endif // EIGEN_TEST_EVALUATORS
+
+#ifdef EIGEN_TEST_EVALUATORS
+
+namespace internal {
+
+template<typename Lhs, typename Rhs, int ProductType>
+struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
+ typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
+ LhsNested lhsNested(lhs);
+ RhsNested rhsNested(rhs);
+
+ dst.setZero();
+ internal::sparse_time_dense_product(lhsNested, rhsNested, dst, typename Dest::Scalar(1));
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductType>
+struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
+{
+ template<typename Dest>
+ static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
+ {
+ typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
+ typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
+ LhsNested lhsNested(lhs);
+ RhsNested rhsNested(rhs);
+
+ dst.setZero();
+ // transpoe everything
+ Transpose<Dest> dstT(dst);
+ internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, typename Dest::Scalar(1));
+ }
+};
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ product_evaluator(const XprType& xpr)
+ : m_result(xpr.rows(), xpr.cols())
+ {
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
+ }
+
+protected:
+ PlainObject m_result;
+};
+
+template<typename Lhs, typename Rhs, int ProductTag>
+struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
+ : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
+{
+ typedef Product<Lhs, Rhs, DefaultProduct> XprType;
+ typedef typename XprType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ product_evaluator(const XprType& xpr)
+ : m_result(xpr.rows(), xpr.cols())
+ {
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
+ }
+
+protected:
+ PlainObject m_result;
+};
+
+} // end namespace internal
+
+#endif // EIGEN_TEST_EVALUATORS
} // end namespace Eigen
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index 3a81916fb..3bc5af86d 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -282,6 +282,17 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
const SparseDiagonalProduct<OtherDerived,Derived>
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
+
+ /** dense * sparse (return a dense object unless it is an outer product) */
+ template<typename OtherDerived> friend
+ const typename DenseSparseProductReturnType<OtherDerived,Derived>::Type
+ operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
+ { return typename DenseSparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
+
+ /** sparse * dense (returns a dense object unless it is an outer product) */
+ template<typename OtherDerived>
+ const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
+ operator*(const MatrixBase<OtherDerived> &other) const;
#else // EIGEN_TEST_EVALUATORS
// sparse * diagonal
template<typename OtherDerived>
@@ -299,18 +310,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
template<typename OtherDerived>
const Product<Derived,OtherDerived>
operator*(const SparseMatrixBase<OtherDerived> &other) const;
-#endif // EIGEN_TEST_EVALUATORS
-
- /** dense * sparse (return a dense object unless it is an outer product) */
- template<typename OtherDerived> friend
- const typename DenseSparseProductReturnType<OtherDerived,Derived>::Type
- operator*(const MatrixBase<OtherDerived>& lhs, const Derived& rhs)
- { return typename DenseSparseProductReturnType<OtherDerived,Derived>::Type(lhs.derived(),rhs); }
-
- /** sparse * dense (returns a dense object unless it is an outer product) */
+
+ // sparse * dense
template<typename OtherDerived>
- const typename SparseDenseProductReturnType<Derived,OtherDerived>::Type
- operator*(const MatrixBase<OtherDerived> &other) const;
+ const Product<Derived,OtherDerived>
+ operator*(const MatrixBase<OtherDerived> &other) const
+ { return Product<Derived,OtherDerived>(derived(), other.derived()); }
+
+ // dense * sparse
+ template<typename OtherDerived> friend
+ const Product<OtherDerived,Derived>
+ operator*(const MatrixBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
+ { return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
+#endif // EIGEN_TEST_EVALUATORS
/** \returns an expression of P H P^-1 where H is the matrix represented by \c *this */
SparseSymmetricPermutationProduct<Derived,Upper|Lower> twistedBy(const PermutationMatrix<Dynamic,Dynamic,Index>& perm) const