diff options
-rw-r--r-- | Eigen/src/Core/ProductEvaluators.h | 48 |
1 files changed, 27 insertions, 21 deletions
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index c12ebb5a0..3cf7c0e16 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -842,17 +842,19 @@ struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, * Internal helper class implementing the product between a permutation matrix and a matrix. * This class is specialized for DenseShape below and for SparseShape in SparseCore/SparsePermutation.h */ -template<typename MatrixType, int Side, bool Transposed, typename MatrixShape> +template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape> struct permutation_matrix_product; -template<typename MatrixType, int Side, bool Transposed> -struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape> +template<typename ExpressionType, int Side, bool Transposed> +struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape> { + typedef typename nested_eval<ExpressionType, 1>::type MatrixType; typedef typename remove_all<MatrixType>::type MatrixTypeCleaned; template<typename Dest, typename PermutationType> - static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat) + static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) { + MatrixType mat(xpr); const Index n = Side==OnTheLeft ? mat.rows() : mat.cols(); // FIXME we need an is_same for expression that is not sensitive to constness. For instance // is_same_xpr<Block<const Matrix>, Block<Matrix> >::value should be true. @@ -893,7 +895,7 @@ struct permutation_matrix_product<MatrixType, Side, Transposed, DenseShape> = - Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixType::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixType::ColsAtCompileTime> + Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixTypeCleaned::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixTypeCleaned::ColsAtCompileTime> (mat, ((Side==OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i); } } @@ -951,26 +953,30 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, * \class transposition_matrix_product * Internal helper class implementing the product between a permutation matrix and a matrix. */ -template<typename MatrixType, int Side, bool Transposed, typename MatrixShape> +template<typename ExpressionType, int Side, bool Transposed, typename ExpressionShape> struct transposition_matrix_product { - template<typename Dest, typename TranspositionType> - static inline void run(Dest& dst, const TranspositionType& tr, const MatrixType& mat) - { - typedef typename TranspositionType::StorageIndex StorageIndex; - const Index size = tr.size(); - StorageIndex j = 0; + typedef typename nested_eval<ExpressionType, 1>::type MatrixType; + typedef typename remove_all<MatrixType>::type MatrixTypeCleaned; + + template<typename Dest, typename TranspositionType> + static inline void run(Dest& dst, const TranspositionType& tr, const ExpressionType& xpr) + { + MatrixType mat(xpr); + typedef typename TranspositionType::StorageIndex StorageIndex; + const Index size = tr.size(); + StorageIndex j = 0; - if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat))) - dst = mat; + if(!(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat))) + dst = mat; - for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k) - if(Index(j=tr.coeff(k))!=k) - { - if(Side==OnTheLeft) dst.row(k).swap(dst.row(j)); - else if(Side==OnTheRight) dst.col(k).swap(dst.col(j)); - } - } + for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k) + if(Index(j=tr.coeff(k))!=k) + { + if(Side==OnTheLeft) dst.row(k).swap(dst.row(j)); + else if(Side==OnTheRight) dst.col(k).swap(dst.col(j)); + } + } }; template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> |