diff options
author | Gael Guennebaud <g.gael@free.fr> | 2015-12-01 14:38:47 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2015-12-01 14:38:47 +0100 |
commit | 0bb12fa61437e55ce563d076938593bebff7f0fc (patch) | |
tree | 0b78b95457df64d4addd62b77899d0d645e59e53 /Eigen/src/Core | |
parent | 1663d15da7daf6cea77b6d0072849e77428db7a4 (diff) |
Add LU::transpose().solve() and LU::adjoint().solve() API.
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 1 | ||||
-rw-r--r-- | Eigen/src/Core/Inverse.h | 24 | ||||
-rw-r--r-- | Eigen/src/Core/Solve.h | 27 | ||||
-rw-r--r-- | Eigen/src/Core/SolverBase.h | 130 | ||||
-rw-r--r-- | Eigen/src/Core/Transpose.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/util/Constants.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 1 |
7 files changed, 169 insertions, 20 deletions
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index a8b359085..42ad452f7 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -29,6 +29,7 @@ struct storage_kind_to_evaluator_kind { template<typename StorageKind> struct storage_kind_to_shape; template<> struct storage_kind_to_shape<Dense> { typedef DenseShape Shape; }; +template<> struct storage_kind_to_shape<SolverStorage> { typedef SolverShape Shape; }; template<> struct storage_kind_to_shape<PermutationStorage> { typedef PermutationShape Shape; }; template<> struct storage_kind_to_shape<TranspositionsStorage> { typedef TranspositionsShape Shape; }; diff --git a/Eigen/src/Core/Inverse.h b/Eigen/src/Core/Inverse.h index 8ba1a12d9..f3ec84990 100644 --- a/Eigen/src/Core/Inverse.h +++ b/Eigen/src/Core/Inverse.h @@ -48,6 +48,7 @@ public: typedef typename internal::ref_selector<XprType>::type XprTypeNested; typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned; typedef typename internal::ref_selector<Inverse>::type Nested; + typedef typename internal::remove_all<XprType>::type NestedExpression; explicit Inverse(const XprType &xpr) : m_xpr(xpr) @@ -62,25 +63,16 @@ protected: XprTypeNested m_xpr; }; -/** \internal - * Specialization of the Inverse expression for dense expressions. - * Direct access to the coefficients are discared. - * FIXME this intermediate class is probably not needed anymore. - */ -template<typename XprType> -class InverseImpl<XprType,Dense> - : public MatrixBase<Inverse<XprType> > +// Generic API dispatcher +template<typename XprType, typename StorageKind> +class InverseImpl + : public internal::generic_xpr_base<Inverse<XprType> >::type { - typedef Inverse<XprType> Derived; - public: - - typedef MatrixBase<Derived> Base; - EIGEN_DENSE_PUBLIC_INTERFACE(Derived) - typedef typename internal::remove_all<XprType>::type NestedExpression; - + typedef typename internal::generic_xpr_base<Inverse<XprType> >::type Base; + typedef typename XprType::Scalar Scalar; private: - + Scalar coeff(Index row, Index col) const; Scalar coeff(Index i) const; }; diff --git a/Eigen/src/Core/Solve.h b/Eigen/src/Core/Solve.h index 2d163fe2a..ba2ee53b8 100644 --- a/Eigen/src/Core/Solve.h +++ b/Eigen/src/Core/Solve.h @@ -34,12 +34,11 @@ template<typename Decomposition, typename RhsType,typename StorageKind> struct s template<typename Decomposition, typename RhsType> struct solve_traits<Decomposition,RhsType,Dense> { - typedef typename Decomposition::MatrixType MatrixType; typedef Matrix<typename RhsType::Scalar, - MatrixType::ColsAtCompileTime, + Decomposition::ColsAtCompileTime, RhsType::ColsAtCompileTime, RhsType::PlainObject::Options, - MatrixType::MaxColsAtCompileTime, + Decomposition::MaxColsAtCompileTime, RhsType::MaxColsAtCompileTime> PlainObject; }; @@ -145,6 +144,28 @@ struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar } }; +// Specialization for "dst = dec.transpose().solve(rhs)" +template<typename DstXprType, typename DecType, typename RhsType, typename Scalar> +struct Assignment<DstXprType, Solve<Transpose<const DecType>,RhsType>, internal::assign_op<Scalar>, Dense2Dense, Scalar> +{ + typedef Solve<Transpose<const DecType>,RhsType> SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &) + { + src.dec().nestedExpression().template _solve_impl_transposed<false>(src.rhs(), dst); + } +}; + +// Specialization for "dst = dec.adjoint().solve(rhs)" +template<typename DstXprType, typename DecType, typename RhsType, typename Scalar> +struct Assignment<DstXprType, Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType>, internal::assign_op<Scalar>, Dense2Dense, Scalar> +{ + typedef Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType> SrcXprType; + static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &) + { + src.dec().nestedExpression().nestedExpression().template _solve_impl_transposed<true>(src.rhs(), dst); + } +}; + } // end namepsace internal } // end namespace Eigen diff --git a/Eigen/src/Core/SolverBase.h b/Eigen/src/Core/SolverBase.h new file mode 100644 index 000000000..8a4adc229 --- /dev/null +++ b/Eigen/src/Core/SolverBase.h @@ -0,0 +1,130 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_SOLVERBASE_H +#define EIGEN_SOLVERBASE_H + +namespace Eigen { + +namespace internal { + + + +} // end namespace internal + +/** \class SolverBase + * \brief A base class for matrix decomposition and solvers + * + * \tparam Derived the actual type of the decomposition/solver. + * + * Any matrix decomposition inheriting this base class provide the following API: + * + * \code + * MatrixType A, b, x; + * DecompositionType dec(A); + * x = dec.solve(b); // solve A * x = b + * x = dec.transpose().solve(b); // solve A^T * x = b + * x = dec.adjoint().solve(b); // solve A' * x = b + * \endcode + * + * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors. + * + * \sa class PartialPivLU, class FullPivLU + */ +template<typename Derived> +class SolverBase : public EigenBase<Derived> +{ + public: + + typedef EigenBase<Derived> Base; + typedef typename internal::traits<Derived>::Scalar Scalar; + typedef Scalar CoeffReturnType; + + enum { + RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime, + ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime, + SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime, + internal::traits<Derived>::ColsAtCompileTime>::ret), + MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime, + MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime, + MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime, + internal::traits<Derived>::MaxColsAtCompileTime>::ret), + IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1 + || internal::traits<Derived>::MaxColsAtCompileTime == 1 + }; + + /** Default constructor */ + SolverBase() + {} + + ~SolverBase() + {} + + using Base::derived; + + /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A. + */ + template<typename Rhs> + inline const Solve<Derived, Rhs> + solve(const MatrixBase<Rhs>& b) const + { + eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b"); + return Solve<Derived, Rhs>(derived(), b.derived()); + } + + /** \internal the return type of transpose() */ + typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType; + /** \returns an expression of the transposed of the factored matrix. + * + * A typical usage is to solve for the transposed problem A^T x = b: + * \code x = dec.transpose().solve(b); \endcode + * + * \sa adjoint(), solve() + */ + inline ConstTransposeReturnType transpose() const + { + return ConstTransposeReturnType(derived()); + } + + /** \internal the return type of adjoint() */ + typedef typename internal::conditional<NumTraits<Scalar>::IsComplex, + CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>, + ConstTransposeReturnType + >::type AdjointReturnType; + /** \returns an expression of the adjoint of the factored matrix + * + * A typical usage is to solve for the adjoint problem A' x = b: + * \code x = dec.adjoint().solve(b); \endcode + * + * For real scalar types, this function is equivalent to transpose(). + * + * \sa transpose(), solve() + */ + inline AdjointReturnType adjoint() const + { + return AdjointReturnType(derived().transpose()); + } + + protected: +}; + +namespace internal { + +template<typename Derived> +struct generic_xpr_base<Derived, MatrixXpr, SolverStorage> +{ + typedef SolverBase<Derived> type; + +}; + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_SOLVERBASE_H diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h index 2152405d5..5b66eb5e1 100644 --- a/Eigen/src/Core/Transpose.h +++ b/Eigen/src/Core/Transpose.h @@ -39,7 +39,7 @@ struct traits<Transpose<MatrixType> > : public traits<MatrixType> MaxRowsAtCompileTime = MatrixType::MaxColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxRowsAtCompileTime, FlagsLvalueBit = is_lvalue<MatrixType>::value ? LvalueBit : 0, - Flags0 = MatrixTypeNestedPlain::Flags & ~(LvalueBit | NestByRefBit), + Flags0 = traits<MatrixTypeNestedPlain>::Flags & ~(LvalueBit | NestByRefBit), Flags1 = Flags0 | FlagsLvalueBit, Flags = Flags1 ^ RowMajorBit, InnerStrideAtCompileTime = inner_stride_at_compile_time<MatrixType>::ret, diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 28852c8c3..a364f48d1 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -492,6 +492,9 @@ struct Dense {}; /** The type used to identify a general sparse storage. */ struct Sparse {}; +/** The type used to identify a general solver (foctored) storage. */ +struct SolverStorage {}; + /** The type used to identify a permutation storage. */ struct PermutationStorage {}; @@ -506,6 +509,7 @@ struct ArrayXpr {}; // An evaluator must define its shape. By default, it can be one of the following: struct DenseShape { static std::string debugName() { return "DenseShape"; } }; +struct SolverShape { static std::string debugName() { return "SolverShape"; } }; struct HomogeneousShape { static std::string debugName() { return "HomogeneousShape"; } }; struct DiagonalShape { static std::string debugName() { return "DiagonalShape"; } }; struct BandShape { static std::string debugName() { return "BandShape"; } }; diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 1aa81abf8..483af876f 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -132,6 +132,7 @@ template<typename MatrixType> struct CommaInitializer; template<typename Derived> class ReturnByValue; template<typename ExpressionType> class ArrayWrapper; template<typename ExpressionType> class MatrixWrapper; +template<typename Derived> class SolverBase; template<typename XprType> class InnerIterator; namespace internal { |