diff options
author | Gael Guennebaud <g.gael@free.fr> | 2017-01-16 22:21:23 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2017-01-16 22:21:23 +0100 |
commit | 4989922be2708378b2438db5a843640ec468ce4c (patch) | |
tree | 19f81b7439f227edde89cd38fccffd94fbc1d3af /Eigen/src | |
parent | 12e22a2844d060cfbeab7a48512046ee59709e53 (diff) |
Add support for symbolic expressions as arguments of operator()
Diffstat (limited to 'Eigen/src')
-rw-r--r-- | Eigen/src/Core/ArithmeticSequence.h | 66 | ||||
-rw-r--r-- | Eigen/src/Core/IndexedView.h | 8 | ||||
-rw-r--r-- | Eigen/src/Core/util/IndexedViewHelper.h | 80 | ||||
-rw-r--r-- | Eigen/src/plugins/IndexedViewMethods.h | 41 |
4 files changed, 116 insertions, 79 deletions
diff --git a/Eigen/src/Core/ArithmeticSequence.h b/Eigen/src/Core/ArithmeticSequence.h index 79e6bb74e..056ace1f2 100644 --- a/Eigen/src/Core/ArithmeticSequence.h +++ b/Eigen/src/Core/ArithmeticSequence.h @@ -12,56 +12,6 @@ namespace Eigen { -/** \namespace Eigen::placeholders - * \ingroup Core_Module - * - * Namespace containing symbolic placeholder and identifiers - */ -namespace placeholders { - -namespace internal { -struct symbolic_last_tag {}; -} - -/** \var last - * \ingroup Core_Module - * - * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last element/row/columns - * of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&). - * - * This symbolic placeholder support standard arithmetic operation. - * - * A typical usage example would be: - * \code - * using namespace Eigen; - * using Eigen::placeholders::last; - * VectorXd v(n); - * v(seq(2,last-2)).setOnes(); - * \endcode - * - * \sa end - */ -static const Symbolic::SymbolExpr<internal::symbolic_last_tag> last; - -/** \var end - * \ingroup Core_Module - * - * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last+1 element/row/columns - * of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&). - * - * This symbolic placeholder support standard arithmetic operation. - * It is essentially an alias to last+1 - * - * \sa last - */ -#ifdef EIGEN_PARSED_BY_DOXYGEN -static const auto end = last+1; -#else -static const Symbolic::AddExpr<Symbolic::SymbolExpr<internal::symbolic_last_tag>,Symbolic::ValueExpr> end(last+1); -#endif - -} // end namespace placeholders - //-------------------------------------------------------------------------------- // seq(first,last,incr) and seqN(first,size,incr) //-------------------------------------------------------------------------------- @@ -293,18 +243,6 @@ seq(const Symbolic::BaseExpr<FirstTypeDerived> &f, const Symbolic::BaseExpr<Last namespace internal { -// Replace symbolic last/end "keywords" by their true runtime value -inline Index eval_expr_given_size(Index x, Index /* size */) { return x; } - -template<int N> -fix_t<N> eval_expr_given_size(fix_t<N> x, Index /*size*/) { return x; } - -template<typename Derived> -Index eval_expr_given_size(const Symbolic::BaseExpr<Derived> &x, Index size) -{ - return x.derived().eval(placeholders::last=size-1); -} - // Convert a symbolic span into a usable one (i.e., remove last/end "keywords") template<typename T> struct make_size_type { @@ -318,7 +256,7 @@ struct IndexedViewCompatibleType<ArithmeticSequence<FirstType,SizeType,IncrType> template<typename FirstType,typename SizeType,typename IncrType> ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> -makeIndexedViewCompatible(const ArithmeticSequence<FirstType,SizeType,IncrType>& ids, Index size) { +makeIndexedViewCompatible(const ArithmeticSequence<FirstType,SizeType,IncrType>& ids, Index size,SpecializedType) { return ArithmeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>( eval_expr_given_size(ids.firstObject(),size),eval_expr_given_size(ids.sizeObject(),size),ids.incrObject()); } @@ -436,7 +374,7 @@ struct IndexedViewCompatibleType<legacy::ArithmeticSequenceProxyWithBounds<First template<typename FirstType,typename LastType,typename IncrType> legacy::ArithmeticSequenceProxyWithBounds<Index,Index,IncrType> -makeIndexedViewCompatible(const legacy::ArithmeticSequenceProxyWithBounds<FirstType,LastType,IncrType>& ids, Index size) { +makeIndexedViewCompatible(const legacy::ArithmeticSequenceProxyWithBounds<FirstType,LastType,IncrType>& ids, Index size,SpecializedType) { return legacy::ArithmeticSequenceProxyWithBounds<Index,Index,IncrType>( eval_expr_given_size(ids.firstObject(),size),eval_expr_given_size(ids.lastObject(),size),ids.incrObject()); } diff --git a/Eigen/src/Core/IndexedView.h b/Eigen/src/Core/IndexedView.h index 38ee69638..63878428e 100644 --- a/Eigen/src/Core/IndexedView.h +++ b/Eigen/src/Core/IndexedView.h @@ -45,6 +45,10 @@ struct traits<IndexedView<XprType, RowIndices, ColIndices> > InnerStrideAtCompileTime = InnerIncr<0 || InnerIncr==DynamicIndex || XprInnerStride==Dynamic ? Dynamic : XprInnerStride * InnerIncr, OuterStrideAtCompileTime = OuterIncr<0 || OuterIncr==DynamicIndex || XprOuterstride==Dynamic ? Dynamic : XprOuterstride * OuterIncr, + ReturnAsScalar = is_same<RowIndices,SingleRange>::value && is_same<ColIndices,SingleRange>::value, + ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike, + ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock), + // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag, // but this is too strict regarding negative strides... DirectAccessMask = (InnerIncr!=UndefinedIncr && OuterIncr!=UndefinedIncr && InnerIncr>=0 && OuterIncr>=0) ? DirectAccessBit : 0, @@ -91,8 +95,8 @@ class IndexedViewImpl; * - decltype(ArrayXi::LinSpaced(...)) * - Any view/expressions of the previous types * - Eigen::ArithmeticSequence - * - Eigen::AllRange (helper for Eigen::all) - * - Eigen::IntAsArray (helper for single index) + * - Eigen::internal::AllRange (helper for Eigen::all) + * - Eigen::internal::SingleRange (helper for single index) * - etc. * * In typical usages of %Eigen, this class should never be used directly. It is the return type of diff --git a/Eigen/src/Core/util/IndexedViewHelper.h b/Eigen/src/Core/util/IndexedViewHelper.h index 4f6dd065e..09637a157 100644 --- a/Eigen/src/Core/util/IndexedViewHelper.h +++ b/Eigen/src/Core/util/IndexedViewHelper.h @@ -13,8 +13,70 @@ namespace Eigen { +/** \namespace Eigen::placeholders + * \ingroup Core_Module + * + * Namespace containing symbolic placeholder and identifiers + */ +namespace placeholders { + +namespace internal { +struct symbolic_last_tag {}; +} + +/** \var last + * \ingroup Core_Module + * + * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last element/row/columns + * of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&). + * + * This symbolic placeholder support standard arithmetic operation. + * + * A typical usage example would be: + * \code + * using namespace Eigen; + * using Eigen::placeholders::last; + * VectorXd v(n); + * v(seq(2,last-2)).setOnes(); + * \endcode + * + * \sa end + */ +static const Symbolic::SymbolExpr<internal::symbolic_last_tag> last; + +/** \var end + * \ingroup Core_Module + * + * Can be used as a parameter to Eigen::seq and Eigen::seqN functions to symbolically reference the last+1 element/row/columns + * of the underlying vector or matrix once passed to DenseBase::operator()(const RowIndices&, const ColIndices&). + * + * This symbolic placeholder support standard arithmetic operation. + * It is essentially an alias to last+1 + * + * \sa last + */ +#ifdef EIGEN_PARSED_BY_DOXYGEN +static const auto end = last+1; +#else +static const Symbolic::AddExpr<Symbolic::SymbolExpr<internal::symbolic_last_tag>,Symbolic::ValueExpr> end(last+1); +#endif + +} // end namespace placeholders + namespace internal { + // Replace symbolic last/end "keywords" by their true runtime value +inline Index eval_expr_given_size(Index x, Index /* size */) { return x; } + +template<int N> +fix_t<N> eval_expr_given_size(fix_t<N> x, Index /*size*/) { return x; } + +template<typename Derived> +Index eval_expr_given_size(const Symbolic::BaseExpr<Derived> &x, Index size) +{ + return x.derived().eval(placeholders::last=size-1); +} + // Extract increment/step at compile time template<typename T, typename EnableIf = void> struct get_compile_time_incr { enum { value = UndefinedIncr }; @@ -31,8 +93,8 @@ struct IndexedViewCompatibleType { typedef T type; }; -template<typename T> -const T& makeIndexedViewCompatible(const T& x, Index /*size*/) { return x; } +template<typename T,typename Q> +const T& makeIndexedViewCompatible(const T& x, Index /*size*/, Q) { return x; } //-------------------------------------------------------------------------------- // Handling of a single Index @@ -62,6 +124,18 @@ struct IndexedViewCompatibleType<T,XprSize,typename internal::enable_if<internal typedef SingleRange type; }; +template<typename T, int XprSize> +struct IndexedViewCompatibleType<T, XprSize, typename enable_if<Symbolic::is_symbolic<T>::value>::type> { + typedef SingleRange type; +}; + + +template<typename T> +typename enable_if<Symbolic::is_symbolic<T>::value,SingleRange>::type +makeIndexedViewCompatible(const T& id, Index size, SpecializedType) { + return eval_expr_given_size(id,size); +} + //-------------------------------------------------------------------------------- // Handling of all //-------------------------------------------------------------------------------- @@ -85,7 +159,7 @@ struct IndexedViewCompatibleType<all_t,XprSize> { }; template<typename XprSizeType> -inline AllRange<get_compile_time<XprSizeType>::value> makeIndexedViewCompatible(all_t , XprSizeType size) { +inline AllRange<get_compile_time<XprSizeType>::value> makeIndexedViewCompatible(all_t , XprSizeType size, SpecializedType) { return AllRange<get_compile_time<XprSizeType>::value>(size); } diff --git a/Eigen/src/plugins/IndexedViewMethods.h b/Eigen/src/plugins/IndexedViewMethods.h index 0584a5926..90ade05ed 100644 --- a/Eigen/src/plugins/IndexedViewMethods.h +++ b/Eigen/src/plugins/IndexedViewMethods.h @@ -38,19 +38,24 @@ typedef typename internal::IndexedViewCompatibleType<Index,1>::type IvcIndex; template<typename Indices> typename IvcRowType<Indices>::type ivcRow(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,RowsAtCompileTime>(derived().rows())); + return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,RowsAtCompileTime>(derived().rows()),Specialized); }; template<typename Indices> typename IvcColType<Indices>::type ivcCol(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,ColsAtCompileTime>(derived().cols())); + return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,ColsAtCompileTime>(derived().cols()),Specialized); }; template<typename Indices> typename IvcColType<Indices>::type ivcSize(const Indices& indices) const { - return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,SizeAtCompileTime>(derived().size())); + return internal::makeIndexedViewCompatible(indices, internal::variable_if_dynamic<Index,SizeAtCompileTime>(derived().size()),Specialized); +}; + +template<typename RowIndices, typename ColIndices> +struct valid_indexed_view_overload { + enum { value = !(internal::is_integral<RowIndices>::value && internal::is_integral<ColIndices>::value) }; }; public: @@ -67,9 +72,8 @@ struct EIGEN_INDEXED_VIEW_METHOD_TYPE { // This is the generic version template<typename RowIndices, typename ColIndices> -typename internal::enable_if< - ! (internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::IsBlockAlike - || (internal::is_integral<RowIndices>::value && internal::is_integral<ColIndices>::value)), +typename internal::enable_if<valid_indexed_view_overload<RowIndices,ColIndices>::value + && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsIndexedView, typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type >::type operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST { @@ -80,9 +84,8 @@ operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_IND // The folowing overload returns a Block<> object template<typename RowIndices, typename ColIndices> -typename internal::enable_if< - internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::IsBlockAlike - && !(internal::is_integral<RowIndices>::value && internal::is_integral<ColIndices>::value), +typename internal::enable_if<valid_indexed_view_overload<RowIndices,ColIndices>::value + && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsBlock, typename internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::BlockType>::type operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST { @@ -96,6 +99,17 @@ operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_IND internal::size(actualColIndices)); } +// The following overload returns a Scalar + +template<typename RowIndices, typename ColIndices> +typename internal::enable_if<valid_indexed_view_overload<RowIndices,ColIndices>::value + && internal::traits<typename EIGEN_INDEXED_VIEW_METHOD_TYPE<RowIndices,ColIndices>::type>::ReturnAsScalar, + CoeffReturnType >::type +operator()(const RowIndices& rowIndices, const ColIndices& colIndices) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + return Base::operator()(internal::eval_expr_given_size(rowIndices,rows()),internal::eval_expr_given_size(colIndices,cols())); +} + // The folowing three overloads are needed to handle raw Index[N] arrays. template<typename RowIndicesT, std::size_t RowIndicesN, typename ColIndices> @@ -148,7 +162,7 @@ operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST template<typename Indices> typename internal::enable_if< - (internal::get_compile_time_incr<typename IvcType<Indices>::type>::value==1) && (!internal::is_integral<Indices>::value), + (internal::get_compile_time_incr<typename IvcType<Indices>::type>::value==1) && (!internal::is_integral<Indices>::value) && (!Symbolic::is_symbolic<Indices>::value), VectorBlock<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,internal::array_size<Indices>::value> >::type operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST { @@ -158,6 +172,13 @@ operator()(const Indices& indices) EIGEN_INDEXED_VIEW_METHOD_CONST (derived(), internal::first(actualIndices), internal::size(actualIndices)); } +template<typename IndexType> +typename internal::enable_if<Symbolic::is_symbolic<IndexType>::value, CoeffReturnType >::type +operator()(const IndexType& id) EIGEN_INDEXED_VIEW_METHOD_CONST +{ + return Base::operator()(internal::eval_expr_given_size(id,size())); +} + template<typename IndicesT, std::size_t IndicesN> typename internal::enable_if<IsRowMajor, IndexedView<EIGEN_INDEXED_VIEW_METHOD_CONST Derived,IvcIndex,const IndicesT (&)[IndicesN]> >::type |