aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/DiagonalMatrix.h15
-rw-r--r--Eigen/src/Core/DiagonalProduct.h23
-rw-r--r--Eigen/src/Core/Product.h4
-rw-r--r--test/product_large.cpp9
4 files changed, 46 insertions, 5 deletions
diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h
index de1752a9f..e09797eaf 100644
--- a/Eigen/src/Core/DiagonalMatrix.h
+++ b/Eigen/src/Core/DiagonalMatrix.h
@@ -62,10 +62,19 @@ class DiagonalMatrix : ei_no_assignment_operator,
EIGEN_GENERIC_PUBLIC_INTERFACE(DiagonalMatrix)
+ // needed to evaluate a DiagonalMatrix<Xpr> to a DiagonalMatrix<NestByValue<Vector> >
+ template<typename OtherCoeffsVectorType>
+ inline DiagonalMatrix(const DiagonalMatrix<OtherCoeffsVectorType>& other) : m_coeffs(other.diagonal())
+ {
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType);
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherCoeffsVectorType);
+ ei_assert(m_coeffs.size() > 0);
+ }
+
inline DiagonalMatrix(const CoeffsVectorType& coeffs) : m_coeffs(coeffs)
{
- ei_assert(CoeffsVectorType::IsVectorAtCompileTime
- && coeffs.size() > 0);
+ EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType);
+ ei_assert(coeffs.size() > 0);
}
inline int rows() const { return m_coeffs.size(); }
@@ -76,6 +85,8 @@ class DiagonalMatrix : ei_no_assignment_operator,
return row == col ? m_coeffs.coeff(row) : static_cast<Scalar>(0);
}
+ inline const CoeffsVectorType& diagonal() const { return m_coeffs; }
+
protected:
const typename CoeffsVectorType::Nested m_coeffs;
};
diff --git a/Eigen/src/Core/DiagonalProduct.h b/Eigen/src/Core/DiagonalProduct.h
index f30f8d369..ca0b56872 100644
--- a/Eigen/src/Core/DiagonalProduct.h
+++ b/Eigen/src/Core/DiagonalProduct.h
@@ -26,12 +26,31 @@
#ifndef EIGEN_DIAGONALPRODUCT_H
#define EIGEN_DIAGONALPRODUCT_H
+/** \internal Specialization of ei_nested for DiagonalMatrix.
+ * Unlike ei_nested, if the argument is a DiagonalMatrix and if it must be evaluated,
+ * then it evaluated to a DiagonalMatrix having its own argument evaluated.
+ */
+template<typename T, int N> struct ei_nested_diagonal : ei_nested<T,N> {};
+template<typename T, int N> struct ei_nested_diagonal<DiagonalMatrix<T>,N >
+ : ei_nested<DiagonalMatrix<T>, N, DiagonalMatrix<NestByValue<typename ei_eval<T>::type> > >
+{};
+
+// specialization of ProductReturnType
+template<typename Lhs, typename Rhs>
+struct ProductReturnType<Lhs,Rhs,DiagonalProduct>
+{
+ typedef typename ei_nested_diagonal<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
+ typedef typename ei_nested_diagonal<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
+
+ typedef Product<LhsNested, RhsNested, DiagonalProduct> Type;
+};
+
template<typename LhsNested, typename RhsNested>
struct ei_traits<Product<LhsNested, RhsNested, DiagonalProduct> >
{
// clean the nested types:
- typedef typename ei_unconst<typename ei_unref<LhsNested>::type>::type _LhsNested;
- typedef typename ei_unconst<typename ei_unref<RhsNested>::type>::type _RhsNested;
+ typedef typename ei_cleantype<LhsNested>::type _LhsNested;
+ typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef typename _LhsNested::Scalar Scalar;
enum {
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index c5b06c450..b464304f8 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -62,6 +62,7 @@ struct ProductReturnType
};
// cache friendly specialization
+// note that there is a DiagonalProduct specialization in DiagonalProduct.h
template<typename Lhs, typename Rhs>
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
{
@@ -77,7 +78,8 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
/* Helper class to determine the type of the product, can be either:
* - NormalProduct
* - CacheFriendlyProduct
- * - NormalProduct
+ * - DiagonalProduct
+ * - SparseProduct
*/
template<typename Lhs, typename Rhs> struct ei_product_mode
{
diff --git a/test/product_large.cpp b/test/product_large.cpp
index a1e187889..1c33578be 100644
--- a/test/product_large.cpp
+++ b/test/product_large.cpp
@@ -33,4 +33,13 @@ void test_product_large()
CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) );
CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) );
}
+
+ {
+ // test a specific issue in DiagonalProduct
+ int N = 1000000;
+ VectorXf v = VectorXf::Ones(N);
+ MatrixXf m = MatrixXf::Ones(N,3);
+ m = (v+v).asDiagonal() * m;
+ VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2));
+ }
}