diff options
author | Gael Guennebaud <g.gael@free.fr> | 2011-01-27 09:59:19 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2011-01-27 09:59:19 +0100 |
commit | 0bfb78c8240525a8065e008efad08498d572ef81 (patch) | |
tree | 0f899ed3e3dceb9bfc2b1d356178a83d2999510f | |
parent | fe3bb545e0a1a1cfaf36f6e25a5cc0bfb00337c6 (diff) |
allow mixed complex-real and real-complex dot products
-rw-r--r-- | Eigen/src/Core/Dot.h | 16 | ||||
-rw-r--r-- | Eigen/src/Core/Functors.h | 21 | ||||
-rw-r--r-- | Eigen/src/Core/MatrixBase.h | 5 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 2 | ||||
-rw-r--r-- | test/adjoint.cpp | 5 |
5 files changed, 31 insertions, 18 deletions
diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index 16496273c..0d8856efa 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h @@ -41,18 +41,20 @@ template<typename T, typename U, > struct dot_nocheck { - static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) + typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar; + static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) { - return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum(); + return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum(); } }; template<typename T, typename U> struct dot_nocheck<T, U, true> { - static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) + typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar; + static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) { - return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum(); + return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum(); } }; @@ -70,14 +72,14 @@ struct dot_nocheck<T, U, true> */ template<typename Derived> template<typename OtherDerived> -typename internal::traits<Derived>::Scalar +typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived) EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived,OtherDerived) - EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value), - YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) + typedef internal::scalar_conj_product_op<Scalar,typename OtherDerived::Scalar> func; + EIGEN_CHECK_BINARY_COMPATIBILIY(func,Scalar,typename OtherDerived::Scalar); eigen_assert(size() == other.size()); diff --git a/Eigen/src/Core/Functors.h b/Eigen/src/Core/Functors.h index 325a7dd85..917769c9e 100644 --- a/Eigen/src/Core/Functors.h +++ b/Eigen/src/Core/Functors.h @@ -59,6 +59,7 @@ struct functor_traits<scalar_sum_op<Scalar> > { */ template<typename LhsScalar,typename RhsScalar> struct scalar_product_op { enum { + // TODO vectorize mixed product Vectorizable = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasMul && packet_traits<RhsScalar>::HasMul }; typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type; @@ -84,24 +85,27 @@ struct functor_traits<scalar_product_op<LhsScalar,RhsScalar> > { * * This is a short cut for conj(x) * y which is needed for optimization purpose; in Eigen2 support mode, this becomes x * conj(y) */ -template<typename Scalar> struct scalar_conj_product_op { +template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op { enum { - Conj = NumTraits<Scalar>::IsComplex + Conj = NumTraits<LhsScalar>::IsComplex }; + typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type; + EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op) - EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& b) const - { return conj_helper<Scalar,Scalar,Conj,false>().pmul(a,b); } + EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const + { return conj_helper<LhsScalar,RhsScalar,Conj,false>().pmul(a,b); } + template<typename Packet> EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const { return conj_helper<Packet,Packet,Conj,false>().pmul(a,b); } }; -template<typename Scalar> -struct functor_traits<scalar_conj_product_op<Scalar> > { +template<typename LhsScalar,typename RhsScalar> +struct functor_traits<scalar_conj_product_op<LhsScalar,RhsScalar> > { enum { - Cost = NumTraits<Scalar>::MulCost, - PacketAccess = packet_traits<Scalar>::HasMul + Cost = NumTraits<LhsScalar>::MulCost, + PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMul }; }; @@ -622,6 +626,7 @@ template<typename Scalar> struct functor_has_linear_access<scalar_identity_op<Sc // FIXME move this to functor_traits adding a functor_default template<typename Functor> struct functor_allows_mixing_real_and_complex { enum { ret = 0 }; }; template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; }; +template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_conj_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; }; /** \internal diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index 3b854ca5e..f318bfd5d 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -202,10 +202,11 @@ template<typename Derived> class MatrixBase #if EIGEN2_SUPPORT_STAGE != STAGE20_RESOLVE_API_CONFLICTS template<typename OtherDerived> + typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType #if EIGEN2_SUPPORT_STAGE == STAGE15_RESOLVE_API_CONFLICTS_WARN - EIGEN_DEPRECATED + EIGEN_DEPRECATED Scalar #endif - Scalar dot(const MatrixBase<OtherDerived>& other) const; + dot(const MatrixBase<OtherDerived>& other) const; #endif #ifdef EIGEN2_SUPPORT diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 5a2de7095..a516a7094 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -155,7 +155,7 @@ template<typename LhsScalar, typename RhsScalar, bool ConjLhs=false, bool ConjRh template<typename Scalar> struct scalar_sum_op; template<typename Scalar> struct scalar_difference_op; -template<typename Scalar> struct scalar_conj_product_op; +template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op; template<typename Scalar> struct scalar_quotient_op; template<typename Scalar> struct scalar_opposite_op; template<typename Scalar> struct scalar_conjugate_op; diff --git a/test/adjoint.cpp b/test/adjoint.cpp index 72cbf3406..47889591f 100644 --- a/test/adjoint.cpp +++ b/test/adjoint.cpp @@ -106,6 +106,11 @@ template<typename MatrixType> void adjoint(const MatrixType& m) m3.transposeInPlace(); VERIFY_IS_APPROX(m3,m1.conjugate()); + // check mixed dot product + typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealVectorType; + RealVectorType rv1 = RealVectorType::Random(rows); + VERIFY_IS_APPROX(v1.dot(rv1.template cast<Scalar>()), v1.dot(rv1)); + VERIFY_IS_APPROX(rv1.template cast<Scalar>().dot(v1), rv1.dot(v1)); } void test_adjoint() |