aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/Dot.h16
-rw-r--r--Eigen/src/Core/Functors.h21
-rw-r--r--Eigen/src/Core/MatrixBase.h5
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h2
-rw-r--r--test/adjoint.cpp5
5 files changed, 31 insertions, 18 deletions
diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h
index 16496273c..0d8856efa 100644
--- a/Eigen/src/Core/Dot.h
+++ b/Eigen/src/Core/Dot.h
@@ -41,18 +41,20 @@ template<typename T, typename U,
>
struct dot_nocheck
{
- static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
+ typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar;
+ static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
{
- return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum();
+ return a.template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum();
}
};
template<typename T, typename U>
struct dot_nocheck<T, U, true>
{
- static inline typename traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
+ typedef typename scalar_product_traits<typename traits<T>::Scalar,typename traits<U>::Scalar>::ReturnType ResScalar;
+ static inline ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
{
- return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar> >(b).sum();
+ return a.transpose().template binaryExpr<scalar_conj_product_op<typename traits<T>::Scalar,typename traits<U>::Scalar> >(b).sum();
}
};
@@ -70,14 +72,14 @@ struct dot_nocheck<T, U, true>
*/
template<typename Derived>
template<typename OtherDerived>
-typename internal::traits<Derived>::Scalar
+typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType
MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived,OtherDerived)
- EIGEN_STATIC_ASSERT((internal::is_same<Scalar, typename OtherDerived::Scalar>::value),
- YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
+ typedef internal::scalar_conj_product_op<Scalar,typename OtherDerived::Scalar> func;
+ EIGEN_CHECK_BINARY_COMPATIBILIY(func,Scalar,typename OtherDerived::Scalar);
eigen_assert(size() == other.size());
diff --git a/Eigen/src/Core/Functors.h b/Eigen/src/Core/Functors.h
index 325a7dd85..917769c9e 100644
--- a/Eigen/src/Core/Functors.h
+++ b/Eigen/src/Core/Functors.h
@@ -59,6 +59,7 @@ struct functor_traits<scalar_sum_op<Scalar> > {
*/
template<typename LhsScalar,typename RhsScalar> struct scalar_product_op {
enum {
+ // TODO vectorize mixed product
Vectorizable = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasMul && packet_traits<RhsScalar>::HasMul
};
typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type;
@@ -84,24 +85,27 @@ struct functor_traits<scalar_product_op<LhsScalar,RhsScalar> > {
*
* This is a short cut for conj(x) * y which is needed for optimization purpose; in Eigen2 support mode, this becomes x * conj(y)
*/
-template<typename Scalar> struct scalar_conj_product_op {
+template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op {
enum {
- Conj = NumTraits<Scalar>::IsComplex
+ Conj = NumTraits<LhsScalar>::IsComplex
};
+ typedef typename scalar_product_traits<LhsScalar,RhsScalar>::ReturnType result_type;
+
EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op)
- EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a, const Scalar& b) const
- { return conj_helper<Scalar,Scalar,Conj,false>().pmul(a,b); }
+ EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const
+ { return conj_helper<LhsScalar,RhsScalar,Conj,false>().pmul(a,b); }
+
template<typename Packet>
EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
{ return conj_helper<Packet,Packet,Conj,false>().pmul(a,b); }
};
-template<typename Scalar>
-struct functor_traits<scalar_conj_product_op<Scalar> > {
+template<typename LhsScalar,typename RhsScalar>
+struct functor_traits<scalar_conj_product_op<LhsScalar,RhsScalar> > {
enum {
- Cost = NumTraits<Scalar>::MulCost,
- PacketAccess = packet_traits<Scalar>::HasMul
+ Cost = NumTraits<LhsScalar>::MulCost,
+ PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMul
};
};
@@ -622,6 +626,7 @@ template<typename Scalar> struct functor_has_linear_access<scalar_identity_op<Sc
// FIXME move this to functor_traits adding a functor_default
template<typename Functor> struct functor_allows_mixing_real_and_complex { enum { ret = 0 }; };
template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; };
+template<typename LhsScalar,typename RhsScalar> struct functor_allows_mixing_real_and_complex<scalar_conj_product_op<LhsScalar,RhsScalar> > { enum { ret = 1 }; };
/** \internal
diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h
index 3b854ca5e..f318bfd5d 100644
--- a/Eigen/src/Core/MatrixBase.h
+++ b/Eigen/src/Core/MatrixBase.h
@@ -202,10 +202,11 @@ template<typename Derived> class MatrixBase
#if EIGEN2_SUPPORT_STAGE != STAGE20_RESOLVE_API_CONFLICTS
template<typename OtherDerived>
+ typename internal::scalar_product_traits<typename internal::traits<Derived>::Scalar,typename internal::traits<OtherDerived>::Scalar>::ReturnType
#if EIGEN2_SUPPORT_STAGE == STAGE15_RESOLVE_API_CONFLICTS_WARN
- EIGEN_DEPRECATED
+ EIGEN_DEPRECATED Scalar
#endif
- Scalar dot(const MatrixBase<OtherDerived>& other) const;
+ dot(const MatrixBase<OtherDerived>& other) const;
#endif
#ifdef EIGEN2_SUPPORT
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index 5a2de7095..a516a7094 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -155,7 +155,7 @@ template<typename LhsScalar, typename RhsScalar, bool ConjLhs=false, bool ConjRh
template<typename Scalar> struct scalar_sum_op;
template<typename Scalar> struct scalar_difference_op;
-template<typename Scalar> struct scalar_conj_product_op;
+template<typename LhsScalar,typename RhsScalar> struct scalar_conj_product_op;
template<typename Scalar> struct scalar_quotient_op;
template<typename Scalar> struct scalar_opposite_op;
template<typename Scalar> struct scalar_conjugate_op;
diff --git a/test/adjoint.cpp b/test/adjoint.cpp
index 72cbf3406..47889591f 100644
--- a/test/adjoint.cpp
+++ b/test/adjoint.cpp
@@ -106,6 +106,11 @@ template<typename MatrixType> void adjoint(const MatrixType& m)
m3.transposeInPlace();
VERIFY_IS_APPROX(m3,m1.conjugate());
+ // check mixed dot product
+ typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealVectorType;
+ RealVectorType rv1 = RealVectorType::Random(rows);
+ VERIFY_IS_APPROX(v1.dot(rv1.template cast<Scalar>()), v1.dot(rv1));
+ VERIFY_IS_APPROX(rv1.template cast<Scalar>().dot(v1), rv1.dot(v1));
}
void test_adjoint()