diff options
-rw-r--r-- | Eigen/src/Core/DiagonalMatrix.h | 15 | ||||
-rw-r--r-- | Eigen/src/Core/DiagonalProduct.h | 23 | ||||
-rw-r--r-- | Eigen/src/Core/Product.h | 4 | ||||
-rw-r--r-- | test/product_large.cpp | 9 |
4 files changed, 46 insertions, 5 deletions
diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h index de1752a9f..e09797eaf 100644 --- a/Eigen/src/Core/DiagonalMatrix.h +++ b/Eigen/src/Core/DiagonalMatrix.h @@ -62,10 +62,19 @@ class DiagonalMatrix : ei_no_assignment_operator, EIGEN_GENERIC_PUBLIC_INTERFACE(DiagonalMatrix) + // needed to evaluate a DiagonalMatrix<Xpr> to a DiagonalMatrix<NestByValue<Vector> > + template<typename OtherCoeffsVectorType> + inline DiagonalMatrix(const DiagonalMatrix<OtherCoeffsVectorType>& other) : m_coeffs(other.diagonal()) + { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType); + EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherCoeffsVectorType); + ei_assert(m_coeffs.size() > 0); + } + inline DiagonalMatrix(const CoeffsVectorType& coeffs) : m_coeffs(coeffs) { - ei_assert(CoeffsVectorType::IsVectorAtCompileTime - && coeffs.size() > 0); + EIGEN_STATIC_ASSERT_VECTOR_ONLY(CoeffsVectorType); + ei_assert(coeffs.size() > 0); } inline int rows() const { return m_coeffs.size(); } @@ -76,6 +85,8 @@ class DiagonalMatrix : ei_no_assignment_operator, return row == col ? m_coeffs.coeff(row) : static_cast<Scalar>(0); } + inline const CoeffsVectorType& diagonal() const { return m_coeffs; } + protected: const typename CoeffsVectorType::Nested m_coeffs; }; diff --git a/Eigen/src/Core/DiagonalProduct.h b/Eigen/src/Core/DiagonalProduct.h index f30f8d369..ca0b56872 100644 --- a/Eigen/src/Core/DiagonalProduct.h +++ b/Eigen/src/Core/DiagonalProduct.h @@ -26,12 +26,31 @@ #ifndef EIGEN_DIAGONALPRODUCT_H #define EIGEN_DIAGONALPRODUCT_H +/** \internal Specialization of ei_nested for DiagonalMatrix. + * Unlike ei_nested, if the argument is a DiagonalMatrix and if it must be evaluated, + * then it evaluated to a DiagonalMatrix having its own argument evaluated. + */ +template<typename T, int N> struct ei_nested_diagonal : ei_nested<T,N> {}; +template<typename T, int N> struct ei_nested_diagonal<DiagonalMatrix<T>,N > + : ei_nested<DiagonalMatrix<T>, N, DiagonalMatrix<NestByValue<typename ei_eval<T>::type> > > +{}; + +// specialization of ProductReturnType +template<typename Lhs, typename Rhs> +struct ProductReturnType<Lhs,Rhs,DiagonalProduct> +{ + typedef typename ei_nested_diagonal<Lhs,Rhs::ColsAtCompileTime>::type LhsNested; + typedef typename ei_nested_diagonal<Rhs,Lhs::RowsAtCompileTime>::type RhsNested; + + typedef Product<LhsNested, RhsNested, DiagonalProduct> Type; +}; + template<typename LhsNested, typename RhsNested> struct ei_traits<Product<LhsNested, RhsNested, DiagonalProduct> > { // clean the nested types: - typedef typename ei_unconst<typename ei_unref<LhsNested>::type>::type _LhsNested; - typedef typename ei_unconst<typename ei_unref<RhsNested>::type>::type _RhsNested; + typedef typename ei_cleantype<LhsNested>::type _LhsNested; + typedef typename ei_cleantype<RhsNested>::type _RhsNested; typedef typename _LhsNested::Scalar Scalar; enum { diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index c5b06c450..b464304f8 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -62,6 +62,7 @@ struct ProductReturnType }; // cache friendly specialization +// note that there is a DiagonalProduct specialization in DiagonalProduct.h template<typename Lhs, typename Rhs> struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> { @@ -77,7 +78,8 @@ struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> /* Helper class to determine the type of the product, can be either: * - NormalProduct * - CacheFriendlyProduct - * - NormalProduct + * - DiagonalProduct + * - SparseProduct */ template<typename Lhs, typename Rhs> struct ei_product_mode { diff --git a/test/product_large.cpp b/test/product_large.cpp index a1e187889..1c33578be 100644 --- a/test/product_large.cpp +++ b/test/product_large.cpp @@ -33,4 +33,13 @@ void test_product_large() CALL_SUBTEST( product(MatrixXcf(ei_random<int>(1,50), ei_random<int>(1,50))) ); CALL_SUBTEST( product(Matrix<float,Dynamic,Dynamic,RowMajor>(ei_random<int>(1,320), ei_random<int>(1,320))) ); } + + { + // test a specific issue in DiagonalProduct + int N = 1000000; + VectorXf v = VectorXf::Ones(N); + MatrixXf m = MatrixXf::Ones(N,3); + m = (v+v).asDiagonal() * m; + VERIFY_IS_APPROX(m, MatrixXf::Constant(N,3,2)); + } } |