aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/SparseCore/SparseAssign.h33
-rw-r--r--Eigen/src/SparseCore/SparseMatrixBase.h7
-rw-r--r--test/sparse_basic.cpp12
3 files changed, 52 insertions, 0 deletions
diff --git a/Eigen/src/SparseCore/SparseAssign.h b/Eigen/src/SparseCore/SparseAssign.h
index 93e0adbff..c939f6c92 100644
--- a/Eigen/src/SparseCore/SparseAssign.h
+++ b/Eigen/src/SparseCore/SparseAssign.h
@@ -181,6 +181,39 @@ struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar
}
};
+struct Diagonal2Sparse {};
+
+template<> struct AssignmentKind<SparseShape,DiagonalShape> { typedef Diagonal2Sparse Kind; };
+
+template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
+struct Assignment<DstXprType, SrcXprType, Functor, Diagonal2Sparse, Scalar>
+{
+ typedef typename DstXprType::StorageIndex StorageIndex;
+ typedef Array<StorageIndex,Dynamic,1> ArrayXI;
+ typedef Array<Scalar,Dynamic,1> ArrayXS;
+ template<int Options>
+ static void run(SparseMatrix<Scalar,Options,StorageIndex> &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
+ {
+ Index size = src.diagonal().size();
+ dst.makeCompressed();
+ dst.resizeNonZeros(size);
+ Map<ArrayXI>(dst.innerIndexPtr(), size).setLinSpaced(0,StorageIndex(size)-1);
+ Map<ArrayXI>(dst.outerIndexPtr(), size+1).setLinSpaced(0,StorageIndex(size));
+ Map<ArrayXS>(dst.valuePtr(), size) = src.diagonal();
+ }
+
+ template<typename DstDerived>
+ static void run(SparseMatrixBase<DstDerived> &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &/*func*/)
+ {
+ dst.diagonal() = src.diagonal();
+ }
+
+ static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &/*func*/)
+ { dst.diagonal() += src.diagonal(); }
+
+ static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<typename DstXprType::Scalar> &/*func*/)
+ { dst.diagonal() -= src.diagonal(); }
+};
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h
index f1b5d2a97..4e720904e 100644
--- a/Eigen/src/SparseCore/SparseMatrixBase.h
+++ b/Eigen/src/SparseCore/SparseMatrixBase.h
@@ -242,6 +242,11 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
Derived& operator+=(const SparseMatrixBase<OtherDerived>& other);
template<typename OtherDerived>
Derived& operator-=(const SparseMatrixBase<OtherDerived>& other);
+
+ template<typename OtherDerived>
+ Derived& operator+=(const DiagonalBase<OtherDerived>& other);
+ template<typename OtherDerived>
+ Derived& operator-=(const DiagonalBase<OtherDerived>& other);
Derived& operator*=(const Scalar& other);
Derived& operator/=(const Scalar& other);
@@ -367,6 +372,8 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
static inline StorageIndex convert_index(const Index idx) {
return internal::convert_index<StorageIndex>(idx);
}
+ private:
+ template<typename Dest> void evalTo(Dest &) const;
};
} // end namespace Eigen
diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp
index 75f29a2b4..2ebf4d420 100644
--- a/test/sparse_basic.cpp
+++ b/test/sparse_basic.cpp
@@ -365,6 +365,18 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
VERIFY_IS_APPROX(m2, refMat2);
}
+ // test diagonal to sparse
+ {
+ DenseVector d = DenseVector::Random(rows);
+ DenseMatrix refMat2 = d.asDiagonal();
+ SparseMatrixType m2(rows, rows);
+ m2 = d.asDiagonal();
+ VERIFY_IS_APPROX(m2, refMat2);
+ refMat2 += d.asDiagonal();
+ m2 += d.asDiagonal();
+ VERIFY_IS_APPROX(m2, refMat2);
+ }
+
// test conservative resize
{
std::vector< std::pair<StorageIndex,StorageIndex> > inc;