diff options
-rw-r--r-- | Eigen/src/Core/Ref.h | 16 | ||||
-rw-r--r-- | test/ref.cpp | 24 |
2 files changed, 35 insertions, 5 deletions
diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h index 0cb117949..ea5a2bd5c 100644 --- a/Eigen/src/Core/Ref.h +++ b/Eigen/src/Core/Ref.h @@ -105,7 +105,8 @@ struct traits<Ref<_PlainObjectType, _Options, _StrideType> > OuterStrideMatch = Derived::IsVectorAtCompileTime || int(StrideType::OuterStrideAtCompileTime)==int(Dynamic) || int(StrideType::OuterStrideAtCompileTime)==int(Derived::OuterStrideAtCompileTime), AlignmentMatch = (_Options!=Aligned) || ((PlainObjectType::Flags&AlignedBit)==0) || ((traits<Derived>::Flags&AlignedBit)==AlignedBit), - MatchAtCompileTime = HasDirectAccess && StorageOrderMatch && InnerStrideMatch && OuterStrideMatch && AlignmentMatch + ScalarTypeMatch = internal::is_same<typename PlainObjectType::Scalar, typename Derived::Scalar>::value, + MatchAtCompileTime = HasDirectAccess && StorageOrderMatch && InnerStrideMatch && OuterStrideMatch && AlignmentMatch && ScalarTypeMatch }; typedef typename internal::conditional<MatchAtCompileTime,internal::true_type,internal::false_type>::type type; }; @@ -184,9 +185,11 @@ protected: template<typename PlainObjectType, int Options, typename StrideType> class Ref : public RefBase<Ref<PlainObjectType, Options, StrideType> > { + private: typedef internal::traits<Ref> Traits; template<typename Derived> - EIGEN_DEVICE_FUNC inline Ref(const PlainObjectBase<Derived>& expr); + EIGEN_DEVICE_FUNC inline Ref(const PlainObjectBase<Derived>& expr, + typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0); public: typedef RefBase<Ref> Base; @@ -195,13 +198,15 @@ template<typename PlainObjectType, int Options, typename StrideType> class Ref #ifndef EIGEN_PARSED_BY_DOXYGEN template<typename Derived> - EIGEN_DEVICE_FUNC inline Ref(PlainObjectBase<Derived>& expr) + EIGEN_DEVICE_FUNC inline Ref(PlainObjectBase<Derived>& expr, + typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0) { EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH); Base::construct(expr.derived()); } template<typename Derived> - EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr) + EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr, + typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0) #else template<typename Derived> inline Ref(DenseBase<Derived>& expr) @@ -228,7 +233,8 @@ template<typename TPlainObjectType, int Options, typename StrideType> class Ref< EIGEN_DENSE_PUBLIC_INTERFACE(Ref) template<typename Derived> - EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr) + EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr, + typename internal::enable_if<bool(Traits::template match<Derived>::ScalarTypeMatch),Derived>::type* = 0) { // std::cout << match_helper<Derived>::HasDirectAccess << "," << match_helper<Derived>::OuterStrideMatch << "," << match_helper<Derived>::InnerStrideMatch << "\n"; // std::cout << int(StrideType::OuterStrideAtCompileTime) << " - " << int(Derived::OuterStrideAtCompileTime) << "\n"; diff --git a/test/ref.cpp b/test/ref.cpp index b9470213c..fbe2c450f 100644 --- a/test/ref.cpp +++ b/test/ref.cpp @@ -228,6 +228,28 @@ void call_ref() VERIFY_EVALUATION_COUNT( call_ref_7(c,c), 0); } +typedef Matrix<double,Dynamic,Dynamic,RowMajor> RowMatrixXd; +int test_ref_overload_fun1(Ref<MatrixXd> ) { return 1; } +int test_ref_overload_fun1(Ref<RowMatrixXd> ) { return 2; } +int test_ref_overload_fun1(Ref<MatrixXf> ) { return 3; } + +int test_ref_overload_fun2(Ref<const MatrixXd> ) { return 4; } +int test_ref_overload_fun2(Ref<const MatrixXf> ) { return 5; } + +// See also bug 969 +void test_ref_overloads() +{ + MatrixXd Ad, Bd; + RowMatrixXd rAd, rBd; + VERIFY( test_ref_overload_fun1(Ad)==1 ); + VERIFY( test_ref_overload_fun1(rAd)==2 ); + + MatrixXf Af, Bf; + VERIFY( test_ref_overload_fun2(Ad)==4 ); + VERIFY( test_ref_overload_fun2(Ad+Bd)==4 ); + VERIFY( test_ref_overload_fun2(Af+Bf)==5 ); +} + void test_ref() { for(int i = 0; i < g_repeat; i++) { @@ -248,4 +270,6 @@ void test_ref() CALL_SUBTEST_5( ref_matrix(MatrixXi(internal::random<int>(1,10),internal::random<int>(1,10))) ); CALL_SUBTEST_6( call_ref() ); } + + CALL_SUBTEST_7( test_ref_overloads() ); } |