aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2015-11-04 17:42:07 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2015-11-04 17:42:07 +0100
commit902750826b52de97f2cd48d91fcf4f70d874e93c (patch)
tree71fdeae27b0cd66ab9f56e9ed49fd9f9a8f969fe
parentf6b1deebab06dda60ffea454ead6b581a138b513 (diff)
Add support for dense.cwiseProduct(sparse)
This also fixes a regression regarding (dense*sparse).diagonal()
-rw-r--r--Eigen/src/Core/MatrixBase.h9
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h4
-rw-r--r--Eigen/src/SparseCore/SparseCwiseBinaryOp.h4
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h22
-rw-r--r--Eigen/src/SparseCore/SparseUtil.h1
-rw-r--r--test/sparse_basic.cpp2
6 files changed, 26 insertions, 16 deletions
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h
index b5afff005..9d612c852 100644
--- a/Eigen/src/Core/MatrixBase.h
+++ b/Eigen/src/Core/MatrixBase.h
@@ -438,6 +438,15 @@ template<typename Derived> class MatrixBase
template<typename OtherScalar>
void applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j);
+///////// SparseCore module /////////
+
+ template<typename OtherDerived>
+ EIGEN_STRONG_INLINE const typename SparseMatrixBase<OtherDerived>::template CwiseProductDenseReturnType<Derived>::Type
+ cwiseProduct(const SparseMatrixBase<OtherDerived> &other) const
+ {
+ return other.cwiseProduct(derived());
+ }
+
///////// MatrixFunctions module /////////
typedef typename internal::stem_function<Scalar>::type StemFunction;
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index c2d5f4316..34697765d 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -265,7 +265,6 @@ template<typename Scalar> class Rotation2D;
template<typename Scalar> class AngleAxis;
template<typename Scalar,int Dim> class Translation;
template<typename Scalar,int Dim> class AlignedBox;
-
template<typename Scalar, int Options = AutoAlign> class Quaternion;
template<typename Scalar,int Dim,int Mode,int _Options=AutoAlign> class Transform;
template <typename _Scalar, int _AmbientDim, int Options=AutoAlign> class ParametrizedLine;
@@ -273,6 +272,9 @@ template <typename _Scalar, int _AmbientDim, int Options=AutoAlign> class Hyperp
template<typename Scalar> class UniformScaling;
template<typename MatrixType,int Direction> class Homogeneous;
+// Sparse module:
+template<typename Derived> class SparseMatrixBase;
+
// MatrixFunctions module
template<typename Derived> struct MatrixExponentialReturnValue;
template<typename Derived> class MatrixFunctionReturnValue;
diff --git a/Eigen/src/SparseCore/SparseCwiseBinaryOp.h b/Eigen/src/SparseCore/SparseCwiseBinaryOp.h
index abbbf397b..90f702ee3 100644
--- a/Eigen/src/SparseCore/SparseCwiseBinaryOp.h
+++ b/Eigen/src/SparseCore/SparseCwiseBinaryOp.h
@@ -423,10 +423,10 @@ Derived& SparseMatrixBase<Derived>::operator-=(const DiagonalBase<OtherDerived>&
template<typename Derived>
template<typename OtherDerived>
-EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
+EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type
SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
{
- return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
+ return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived());
}
} // end namespace Eigen
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index ff417302f..648ae1f8a 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -262,20 +262,18 @@ template<typename Derived> class SparseMatrixBase
Derived& operator*=(const Scalar& other);
Derived& operator/=(const Scalar& other);
- #define EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE \
- CwiseBinaryOp< \
- internal::scalar_product_op< \
- typename internal::scalar_product_traits< \
- typename internal::traits<Derived>::Scalar, \
- typename internal::traits<OtherDerived>::Scalar \
- >::ReturnType \
- >, \
- const Derived, \
- const OtherDerived \
- >
+ template<typename OtherDerived> struct CwiseProductDenseReturnType {
+ typedef CwiseBinaryOp<internal::scalar_product_op<typename internal::scalar_product_traits<
+ typename internal::traits<Derived>::Scalar,
+ typename internal::traits<OtherDerived>::Scalar
+ >::ReturnType>,
+ const Derived,
+ const OtherDerived
+ > Type;
+ };
template<typename OtherDerived>
- EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
+ EIGEN_STRONG_INLINE const typename CwiseProductDenseReturnType<OtherDerived>::Type
cwiseProduct(const MatrixBase<OtherDerived> &other) const;
// sparse * diagonal
diff --git a/Eigen/src/SparseCore/SparseUtil.h b/Eigen/src/SparseCore/SparseUtil.h
index 62f0f6864..74df0d496 100644
--- a/Eigen/src/SparseCore/SparseUtil.h
+++ b/Eigen/src/SparseCore/SparseUtil.h
@@ -49,7 +49,6 @@ const int InnerRandomAccessPattern = 0x2 | CoherentAccessPattern;
const int OuterRandomAccessPattern = 0x4 | CoherentAccessPattern;
const int RandomAccessPattern = 0x8 | OuterRandomAccessPattern | InnerRandomAccessPattern;
-template<typename Derived> class SparseMatrixBase;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class DynamicSparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseVector;
diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp
index 0c7d2c652..2d0f5819f 100644
--- a/test/sparse_basic.cpp
+++ b/test/sparse_basic.cpp
@@ -188,6 +188,8 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
refM4.setRandom();
// sparse cwise* dense
VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4));
+ // dense cwise* sparse
+ VERIFY_IS_APPROX(refM4.cwiseProduct(m3), refM4.cwiseProduct(refM3));
// VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4);
// test aliasing