diff options
Diffstat (limited to 'Eigen/src/SparseCore/SparsePermutation.h')
-rw-r--r-- | Eigen/src/SparseCore/SparsePermutation.h | 117 |
1 files changed, 103 insertions, 14 deletions
diff --git a/Eigen/src/SparseCore/SparsePermutation.h b/Eigen/src/SparseCore/SparsePermutation.h index b85be93f6..21411f232 100644 --- a/Eigen/src/SparseCore/SparsePermutation.h +++ b/Eigen/src/SparseCore/SparsePermutation.h @@ -61,7 +61,7 @@ struct permut_sparsematrix_product_retval for(Index j=0; j<m_matrix.outerSize(); ++j) { Index jp = m_permutation.indices().coeff(j); - sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = m_matrix.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).size(); + sizes[((Side==OnTheLeft) ^ Transposed) ? jp : j] = m_matrix.innerVector(((Side==OnTheRight) ^ Transposed) ? jp : j).nonZeros(); } tmp.reserve(sizes); for(Index j=0; j<m_matrix.outerSize(); ++j) @@ -103,44 +103,133 @@ struct permut_sparsematrix_product_retval } +namespace internal { + +template <int ProductTag> struct product_promote_storage_type<Sparse, PermutationStorage, ProductTag> { typedef Sparse ret; }; +template <int ProductTag> struct product_promote_storage_type<PermutationStorage, Sparse, ProductTag> { typedef Sparse ret; }; + +// TODO, the following need cleaning, this is just a copy-past of the dense case + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, false> pmpr(lhs, rhs); + pmpr.evalTo(dst); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, false> pmpr(rhs, lhs); + pmpr.evalTo(dst); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, SparseShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) + { + permut_sparsematrix_product_retval<Lhs, Rhs, OnTheLeft, true> pmpr(lhs.nestedPermutation(), rhs); + pmpr.evalTo(dst); + } +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct generic_product_impl<Lhs, Transpose<Rhs>, SparseShape, PermutationShape, ProductTag> +{ + template<typename Dest> + static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) + { + permut_sparsematrix_product_retval<Rhs, Lhs, OnTheRight, true> pmpr(rhs.nestedPermutation(), lhs); + pmpr.evalTo(dst); + } +}; + +// TODO, the following two overloads are only needed to define the right temporary type through +// typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType +// while it should be correctly handled by traits<Product<> >::PlainObject +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, PermutationShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : public evaluator<typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType>::type +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef typename traits<permut_sparsematrix_product_retval<Lhs,Rhs,OnTheRight,false> >::ReturnType PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + explicit product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast<Base*>(this)) Base(m_result); + generic_product_impl<Lhs, Rhs, PermutationShape, SparseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + +template<typename Lhs, typename Rhs, int ProductTag> +struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, PermutationShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> + : public evaluator<typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType>::type +{ + typedef Product<Lhs, Rhs, DefaultProduct> XprType; + typedef typename traits<permut_sparsematrix_product_retval<Rhs,Lhs,OnTheRight,false> >::ReturnType PlainObject; + typedef typename evaluator<PlainObject>::type Base; + + explicit product_evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + ::new (static_cast<Base*>(this)) Base(m_result); + generic_product_impl<Lhs, Rhs, SparseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs()); + } + +protected: + PlainObject m_result; +}; + +} // end namespace internal /** \returns the matrix with the permutation applied to the columns */ template<typename SparseDerived, typename PermDerived> -inline const internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheRight, false> +inline const Product<SparseDerived, PermDerived> operator*(const SparseMatrixBase<SparseDerived>& matrix, const PermutationBase<PermDerived>& perm) -{ - return internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheRight, false>(perm, matrix.derived()); -} +{ return Product<SparseDerived, PermDerived>(matrix.derived(), perm.derived()); } /** \returns the matrix with the permutation applied to the rows */ template<typename SparseDerived, typename PermDerived> -inline const internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheLeft, false> +inline const Product<PermDerived, SparseDerived> operator*( const PermutationBase<PermDerived>& perm, const SparseMatrixBase<SparseDerived>& matrix) -{ - return internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheLeft, false>(perm, matrix.derived()); -} - +{ return Product<PermDerived, SparseDerived>(perm.derived(), matrix.derived()); } +// TODO, the following specializations should not be needed as Transpose<Permutation*> should be a PermutationBase. /** \returns the matrix with the inverse permutation applied to the columns. */ template<typename SparseDerived, typename PermDerived> -inline const internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheRight, true> +inline const Product<SparseDerived, Transpose<PermutationBase<PermDerived> > > operator*(const SparseMatrixBase<SparseDerived>& matrix, const Transpose<PermutationBase<PermDerived> >& tperm) { - return internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheRight, true>(tperm.nestedPermutation(), matrix.derived()); + return Product<SparseDerived, Transpose<PermutationBase<PermDerived> > >(matrix.derived(), tperm); } /** \returns the matrix with the inverse permutation applied to the rows. */ template<typename SparseDerived, typename PermDerived> -inline const internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheLeft, true> +inline const Product<Transpose<PermutationBase<PermDerived> >, SparseDerived> operator*(const Transpose<PermutationBase<PermDerived> >& tperm, const SparseMatrixBase<SparseDerived>& matrix) { - return internal::permut_sparsematrix_product_retval<PermutationBase<PermDerived>, SparseDerived, OnTheLeft, true>(tperm.nestedPermutation(), matrix.derived()); + return Product<Transpose<PermutationBase<PermDerived> >, SparseDerived>(tperm, matrix.derived()); } } // end namespace Eigen |