aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2014-02-19 11:33:29 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2014-02-19 11:33:29 +0100
commitccc41128fb936563651a3c5b25bfc80f303710bf (patch)
tree95f9c13ef7bfa277c30364bb898b64e27f0f510b
parentb3a07eecc5ad19f418f15d70d1517f8b7e07dc1b (diff)
Add a Solve expression for uniform treatment of solve() methods.
-rw-r--r--Eigen/Core3
-rw-r--r--Eigen/src/Core/CoreEvaluators.h4
-rw-r--r--Eigen/src/Core/Solve.h145
-rw-r--r--Eigen/src/Core/TriangularMatrix.h17
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h3
5 files changed, 169 insertions, 3 deletions
diff --git a/Eigen/Core b/Eigen/Core
index c1d168ec4..e00cd7a16 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -361,6 +361,9 @@ using std::ptrdiff_t;
#include "src/Core/Flagged.h"
#include "src/Core/ProductBase.h"
#include "src/Core/GeneralProduct.h"
+#ifdef EIGEN_ENABLE_EVALUATORS
+#include "src/Core/Solve.h"
+#endif
#include "src/Core/TriangularMatrix.h"
#include "src/Core/SelfAdjointView.h"
#include "src/Core/products/GeneralBlockPanelKernel.h"
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index d02bac0a8..726a6854a 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -495,14 +495,14 @@ struct evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
PacketScalar packet(Index row, Index col) const
{
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(row, col),
- m_rhsImpl.template packet<LoadMode>(row, col));
+ m_rhsImpl.template packet<LoadMode>(row, col));
}
template<int LoadMode>
PacketScalar packet(Index index) const
{
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(index),
- m_rhsImpl.template packet<LoadMode>(index));
+ m_rhsImpl.template packet<LoadMode>(index));
}
protected:
diff --git a/Eigen/src/Core/Solve.h b/Eigen/src/Core/Solve.h
new file mode 100644
index 000000000..2da2aff56
--- /dev/null
+++ b/Eigen/src/Core/Solve.h
@@ -0,0 +1,145 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_INVERSE_H
+#define EIGEN_INVERSE_H
+
+namespace Eigen {
+
+template<typename Decomposition, typename RhsType, typename StorageKind> class SolveImpl;
+
+/** \class Solve
+ * \ingroup Core_Module
+ *
+ * \brief Pseudo expression representing a solving operation
+ *
+ * \tparam Decomposition the type of the matrix or decomposion object
+ * \tparam Rhstype the type of the right-hand side
+ *
+ * This class represents an expression of A.solve(B)
+ * and most of the time this is the only way it is used.
+ *
+ */
+namespace internal {
+
+// this solve_traits class permits to determine the evaluation type with respect to storage kind (Dense vs Sparse)
+template<typename Decomposition, typename RhsType,typename StorageKind> struct solve_traits;
+
+template<typename Decomposition, typename RhsType>
+struct solve_traits<Decomposition,RhsType,Dense>
+{
+ typedef typename Decomposition::MatrixType MatrixType;
+ typedef Matrix<typename RhsType::Scalar,
+ MatrixType::ColsAtCompileTime,
+ RhsType::ColsAtCompileTime,
+ RhsType::PlainObject::Options,
+ MatrixType::MaxColsAtCompileTime,
+ RhsType::MaxColsAtCompileTime> PlainObject;
+};
+
+template<typename Decomposition, typename RhsType>
+struct traits<Solve<Decomposition, RhsType> >
+ : traits<typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject>
+{
+ typedef typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject PlainObject;
+ typedef traits<PlainObject> BaseTraits;
+ enum {
+ Flags = BaseTraits::Flags & RowMajorBit,
+ CoeffReadCost = Dynamic
+ };
+};
+
+}
+
+
+template<typename Decomposition, typename RhsType>
+class Solve : public SolveImpl<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>
+{
+public:
+ typedef typename RhsType::Index Index;
+ typedef typename internal::traits<Solve>::PlainObject PlainObject;
+
+ Solve(const Decomposition &dec, const RhsType &rhs)
+ : m_dec(dec), m_rhs(rhs)
+ {}
+
+ EIGEN_DEVICE_FUNC Index rows() const { return m_dec.rows(); }
+ EIGEN_DEVICE_FUNC Index cols() const { return m_rhs.cols(); }
+
+ EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; }
+ EIGEN_DEVICE_FUNC const RhsType& rhs() const { return m_rhs; }
+
+protected:
+ const Decomposition &m_dec;
+ const RhsType &m_rhs;
+};
+
+
+// Specilaization of the Solve expression for dense results
+template<typename Decomposition, typename RhsType>
+class SolveImpl<Decomposition,RhsType,Dense>
+ : public MatrixBase<Solve<Decomposition,RhsType> >
+{
+ typedef Solve<Decomposition,RhsType> Derived;
+
+public:
+
+ typedef MatrixBase<Solve<Decomposition,RhsType> > Base;
+ EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
+
+private:
+
+ Scalar coeff(Index row, Index col) const;
+ Scalar coeff(Index i) const;
+};
+
+
+namespace internal {
+
+// Evaluator of Solve -> eval into a temporary
+template<typename Decomposition, typename RhsType>
+struct evaluator<Solve<Decomposition,RhsType> >
+ : public evaluator<typename Solve<Decomposition,RhsType>::PlainObject>::type
+{
+ typedef Solve<Decomposition,RhsType> SolveType;
+ typedef typename SolveType::PlainObject PlainObject;
+ typedef typename evaluator<PlainObject>::type Base;
+
+ typedef evaluator type;
+ typedef evaluator nestedType;
+
+ evaluator(const SolveType& solve)
+ : m_result(solve.rows(), solve.cols())
+ {
+ ::new (static_cast<Base*>(this)) Base(m_result);
+ solve.dec()._solve_impl(solve.rhs(), m_result);
+ }
+
+protected:
+ PlainObject m_result;
+};
+
+// Specialization for "dst = dec.solve(rhs)"
+// NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse specialization must exist somewhere
+template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
+struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar>, Dense2Dense, Scalar>
+{
+ typedef Solve<DecType,RhsType> SrcXprType;
+ static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &)
+ {
+ // FIXME shall we resize dst here?
+ src.dec()._solve_impl(src.rhs(), dst);
+ }
+};
+
+} // end namepsace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_SOLVE_H
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h
index ffa2faeaa..76ba76d5f 100644
--- a/Eigen/src/Core/TriangularMatrix.h
+++ b/Eigen/src/Core/TriangularMatrix.h
@@ -437,11 +437,19 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
EIGEN_DEVICE_FUNC
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
+#ifdef EIGEN_TEST_EVALUATORS
+ template<typename Other>
+ EIGEN_DEVICE_FUNC
+ inline const Solve<TriangularView, Other>
+ solve(const MatrixBase<Other>& other) const
+ { return Solve<TriangularView, Other>(*this, other.derived()); }
+#else // EIGEN_TEST_EVALUATORS
template<typename Other>
EIGEN_DEVICE_FUNC
inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
solve(const MatrixBase<Other>& other) const
{ return solve<OnTheLeft>(other); }
+#endif // EIGEN_TEST_EVALUATORS
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
@@ -547,6 +555,15 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
#endif // EIGEN_TEST_EVALUATORS
#ifdef EIGEN_TEST_EVALUATORS
+
+ template<typename RhsType, typename DstType>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
+ if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
+ dst = rhs;
+ this->template solveInPlace(dst);
+ }
+
template<typename ProductType>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TriangularView& _assignProduct(const ProductType& prod, const Scalar& alpha);
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index db0e2b033..8cf13abd9 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -94,7 +94,8 @@ template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
template<typename BinOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp; // TODO deprecated
-template<typename Derived, typename Lhs, typename Rhs> class ProductBase;
+template<typename Derived, typename Lhs, typename Rhs> class ProductBase; // TODO deprecated
+template<typename Decomposition, typename Rhstype> class Solve;
namespace internal {
template<typename Lhs, typename Rhs> struct product_tag;