diff options
Diffstat (limited to 'Eigen/src/Core/Transpose.h')
-rw-r--r-- | Eigen/src/Core/Transpose.h | 122 |
1 files changed, 52 insertions, 70 deletions
diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h index aba3f6670..3bab6092c 100644 --- a/Eigen/src/Core/Transpose.h +++ b/Eigen/src/Core/Transpose.h @@ -2,7 +2,7 @@ // for linear algebra. // // Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com> -// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr> +// Copyright (C) 2009-2014 Gael Guennebaud <gael.guennebaud@inria.fr> // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -29,9 +29,10 @@ namespace Eigen { namespace internal { template<typename MatrixType> -struct traits<Transpose<MatrixType> > : traits<MatrixType> +struct traits<Transpose<MatrixType> > { - typedef typename MatrixType::Scalar Scalar; + typedef typename traits<MatrixType>::Scalar Scalar; + typedef typename traits<MatrixType>::Index Index; typedef typename nested<MatrixType>::type MatrixTypeNested; typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedPlain; typedef typename traits<MatrixType>::StorageKind StorageKind; @@ -45,7 +46,6 @@ struct traits<Transpose<MatrixType> > : traits<MatrixType> Flags0 = MatrixTypeNestedPlain::Flags & ~(LvalueBit | NestByRefBit), Flags1 = Flags0 | FlagsLvalueBit, Flags = Flags1 ^ RowMajorBit, - CoeffReadCost = MatrixTypeNestedPlain::CoeffReadCost, InnerStrideAtCompileTime = inner_stride_at_compile_time<MatrixType>::ret, OuterStrideAtCompileTime = outer_stride_at_compile_time<MatrixType>::ret }; @@ -61,9 +61,10 @@ template<typename MatrixType> class Transpose typedef typename TransposeImpl<MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base; EIGEN_GENERIC_PUBLIC_INTERFACE(Transpose) + typedef typename internal::remove_all<MatrixType>::type NestedExpression; EIGEN_DEVICE_FUNC - inline Transpose(MatrixType& a_matrix) : m_matrix(a_matrix) {} + explicit inline Transpose(MatrixType& a_matrix) : m_matrix(a_matrix) {} EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Transpose) @@ -100,12 +101,22 @@ struct TransposeImpl_base<MatrixType, false> } // end namespace internal +// Generic API dispatcher +template<typename XprType, typename StorageKind> +class TransposeImpl + : public internal::generic_xpr_base<Transpose<XprType> >::type +{ +public: + typedef typename internal::generic_xpr_base<Transpose<XprType> >::type Base; +}; + template<typename MatrixType> class TransposeImpl<MatrixType,Dense> : public internal::TransposeImpl_base<MatrixType>::type { public: typedef typename internal::TransposeImpl_base<MatrixType>::type Base; + using Base::coeffRef; EIGEN_DENSE_PUBLIC_INTERFACE(Transpose<MatrixType>) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(TransposeImpl) @@ -118,23 +129,10 @@ template<typename MatrixType> class TransposeImpl<MatrixType,Dense> const Scalar >::type ScalarWithConstIfNotLvalue; - inline ScalarWithConstIfNotLvalue* data() { return derived().nestedExpression().data(); } - inline const Scalar* data() const { return derived().nestedExpression().data(); } - - EIGEN_DEVICE_FUNC - inline ScalarWithConstIfNotLvalue& coeffRef(Index rowId, Index colId) - { - EIGEN_STATIC_ASSERT_LVALUE(MatrixType) - return derived().nestedExpression().const_cast_derived().coeffRef(colId, rowId); - } - - EIGEN_DEVICE_FUNC - inline ScalarWithConstIfNotLvalue& coeffRef(Index index) - { - EIGEN_STATIC_ASSERT_LVALUE(MatrixType) - return derived().nestedExpression().const_cast_derived().coeffRef(index); - } + EIGEN_DEVICE_FUNC inline ScalarWithConstIfNotLvalue* data() { return derived().nestedExpression().data(); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return derived().nestedExpression().data(); } + // FIXME: shall we keep the const version of coeffRef? EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index rowId, Index colId) const { @@ -146,42 +144,6 @@ template<typename MatrixType> class TransposeImpl<MatrixType,Dense> { return derived().nestedExpression().coeffRef(index); } - - EIGEN_DEVICE_FUNC - inline CoeffReturnType coeff(Index rowId, Index colId) const - { - return derived().nestedExpression().coeff(colId, rowId); - } - - EIGEN_DEVICE_FUNC - inline CoeffReturnType coeff(Index index) const - { - return derived().nestedExpression().coeff(index); - } - - template<int LoadMode> - inline const PacketScalar packet(Index rowId, Index colId) const - { - return derived().nestedExpression().template packet<LoadMode>(colId, rowId); - } - - template<int LoadMode> - inline void writePacket(Index rowId, Index colId, const PacketScalar& x) - { - derived().nestedExpression().const_cast_derived().template writePacket<LoadMode>(colId, rowId, x); - } - - template<int LoadMode> - inline const PacketScalar packet(Index index) const - { - return derived().nestedExpression().template packet<LoadMode>(index); - } - - template<int LoadMode> - inline void writePacket(Index index, const PacketScalar& x) - { - derived().nestedExpression().const_cast_derived().template writePacket<LoadMode>(index, x); - } }; /** \returns an expression of the transpose of *this. @@ -207,7 +169,7 @@ template<typename Derived> inline Transpose<Derived> DenseBase<Derived>::transpose() { - return derived(); + return TransposeReturnType(derived()); } /** This is the const version of transpose(). @@ -245,8 +207,7 @@ template<typename Derived> inline const typename MatrixBase<Derived>::AdjointReturnType MatrixBase<Derived>::adjoint() const { - return this->transpose(); // in the complex case, the .conjugate() is be implicit here - // due to implicit conversion to return type + return AdjointReturnType(this->transpose()); } /*************************************************************************** @@ -256,18 +217,39 @@ MatrixBase<Derived>::adjoint() const namespace internal { template<typename MatrixType, - bool IsSquare = (MatrixType::RowsAtCompileTime == MatrixType::ColsAtCompileTime) && MatrixType::RowsAtCompileTime!=Dynamic> + bool IsSquare = (MatrixType::RowsAtCompileTime == MatrixType::ColsAtCompileTime) && MatrixType::RowsAtCompileTime!=Dynamic, + bool MatchPacketSize = + (int(MatrixType::RowsAtCompileTime) == int(internal::packet_traits<typename MatrixType::Scalar>::size)) + && (internal::evaluator<MatrixType>::Flags&PacketAccessBit) > struct inplace_transpose_selector; template<typename MatrixType> -struct inplace_transpose_selector<MatrixType,true> { // square matrix +struct inplace_transpose_selector<MatrixType,true,false> { // square matrix static void run(MatrixType& m) { m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose()); } }; +// TODO: vectorized path is currently limited to LargestPacketSize x LargestPacketSize cases only. template<typename MatrixType> -struct inplace_transpose_selector<MatrixType,false> { // non square matrix +struct inplace_transpose_selector<MatrixType,true,true> { // PacketSize x PacketSize + static void run(MatrixType& m) { + typedef typename MatrixType::Scalar Scalar; + typedef typename internal::packet_traits<typename MatrixType::Scalar>::type Packet; + typedef typename MatrixType::Index Index; + const Index PacketSize = internal::packet_traits<Scalar>::size; + const Index Alignment = internal::evaluator<MatrixType>::Flags&AlignedBit ? Aligned : Unaligned; + PacketBlock<Packet> A; + for (Index i=0; i<PacketSize; ++i) + A.packet[i] = m.template packetByOuterInner<Alignment>(i,0); + internal::ptranspose(A); + for (Index i=0; i<PacketSize; ++i) + m.template writePacket<Alignment>(m.rowIndexByOuterInner(i,0), m.colIndexByOuterInner(i,0), A.packet[i]); + } +}; + +template<typename MatrixType,bool MatchPacketSize> +struct inplace_transpose_selector<MatrixType,false,MatchPacketSize> { // non square matrix static void run(MatrixType& m) { if (m.rows()==m.cols()) m.matrix().template triangularView<StrictlyUpper>().swap(m.matrix().transpose()); @@ -413,15 +395,15 @@ struct checkTransposeAliasing_impl<Derived, OtherDerived, false> } }; -} // end namespace internal - -template<typename Derived> -template<typename OtherDerived> -void DenseBase<Derived>::checkTransposeAliasing(const OtherDerived& other) const +template<typename Dst, typename Src> +void check_for_aliasing(const Dst &dst, const Src &src) { - internal::checkTransposeAliasing_impl<Derived, OtherDerived>::run(derived(), other); + internal::checkTransposeAliasing_impl<Dst, Src>::run(dst, src); } -#endif + +} // end namespace internal + +#endif // EIGEN_NO_DEBUG } // end namespace Eigen |