diff options
author | David Tellenbach <david.tellenbach@me.com> | 2020-07-09 17:24:00 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-07-09 17:24:00 +0000 |
commit | ee4715ff488d5cf685820ad06374cbe7d509ac1a (patch) | |
tree | 29ffd26a21626eda74c101a8df497d323acfb8e3 | |
parent | 8889a2c1c648f5dd1413dc2d94c2407c7ce1bd32 (diff) |
Fix test basic stuff
- Guard fundamental types that are not available pre C++11
- Separate subsequent angle brackets >> by spaces
- Allow casting of Eigen::half and Eigen::bfloat16 to complex types
-rw-r--r-- | Eigen/src/Core/arch/Default/BFloat16.h | 14 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/Half.h | 17 | ||||
-rw-r--r-- | test/basicstuff.cpp | 98 |
3 files changed, 95 insertions, 34 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index 99ce99a27..561304f80 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -27,6 +27,20 @@ namespace Eigen { struct bfloat16; +// explicit conversion operators are no available before C++11 so we first cast +// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly +#if !EIGEN_HAS_CXX11 +namespace internal { +template <typename RealScalar> +struct cast_impl<bfloat16, std::complex<RealScalar> > { + EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const bfloat16 &x) + { + return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x)); + } +}; +} // namespace internal +#endif // EIGEN_HAS_CXX11 + namespace bfloat16_impl { // Make our own __bfloat16_raw definition. diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index b84cfc7db..bbc15d463 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -36,7 +36,7 @@ #ifndef EIGEN_HALF_H #define EIGEN_HALF_H -#if __cplusplus > 199711L +#if EIGEN_HAS_CXX11 #define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type() #else #define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type() @@ -48,6 +48,20 @@ namespace Eigen { struct half; +// explicit conversion operators are no available before C++11 so we first cast +// half to RealScalar rather than to std::complex<RealScalar> directly +#if !EIGEN_HAS_CXX11 +namespace internal { +template <typename RealScalar> +struct cast_impl<half, std::complex<RealScalar> > { + EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const half &x) + { + return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x)); + } +}; +} // namespace internal +#endif // EIGEN_HAS_CXX11 + namespace half_impl { #if !defined(EIGEN_HAS_GPU_FP16) @@ -737,7 +751,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) } #endif - #if defined(EIGEN_GPU_COMPILE_PHASE) namespace Eigen { namespace numext { diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp index 80fc8a07f..f9044a27a 100644 --- a/test/basicstuff.cpp +++ b/test/basicstuff.cpp @@ -195,41 +195,73 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m) VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero()); } +template<typename SrcScalar, typename TgtScalar, bool SrcIsHalfOrBF16 = (internal::is_same<SrcScalar, half>::value || internal::is_same<SrcScalar, bfloat16>::value)> struct casting_test; + + template<typename SrcScalar, typename TgtScalar> -void casting_test() -{ - Matrix<SrcScalar,4,4> m; - for (int i=0; i<m.rows(); ++i) { - for (int j=0; j<m.cols(); ++j) { - m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value(); +struct casting_test<SrcScalar, TgtScalar, false> { + static void run() { + Matrix<SrcScalar,4,4> m; + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + m(i, j) = internal::random_without_cast_overflow<SrcScalar,TgtScalar>::value(); + } + } + Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>(); + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j))); + } } } - Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>(); - for (int i=0; i<m.rows(); ++i) { - for (int j=0; j<m.cols(); ++j) { - VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j))); +}; + +template<typename SrcScalar, typename TgtScalar> +struct casting_test<SrcScalar, TgtScalar, true> { + static void run() { + casting_test<SrcScalar, TgtScalar, false>::run(); + } +}; + +template<typename SrcScalar, typename RealScalar> +struct casting_test<SrcScalar, std::complex<RealScalar>, true> { + static void run() { + typedef std::complex<RealScalar> TgtScalar; + Matrix<SrcScalar,4,4> m; + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + m(i, j) = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value(); + } + } + Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>(); + for (int i=0; i<m.rows(); ++i) { + for (int j=0; j<m.cols(); ++j) { + VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(static_cast<RealScalar>(m(i, j)))); + } } } -} +}; template<typename SrcScalar, typename EnableIf = void> struct casting_test_runner { static void run() { - casting_test<SrcScalar, bool>(); - casting_test<SrcScalar, int8_t>(); - casting_test<SrcScalar, uint8_t>(); - casting_test<SrcScalar, int16_t>(); - casting_test<SrcScalar, uint16_t>(); - casting_test<SrcScalar, int32_t>(); - casting_test<SrcScalar, uint32_t>(); - casting_test<SrcScalar, int64_t>(); - casting_test<SrcScalar, uint64_t>(); - casting_test<SrcScalar, half>(); - casting_test<SrcScalar, bfloat16>(); - casting_test<SrcScalar, float>(); - casting_test<SrcScalar, double>(); - casting_test<SrcScalar, std::complex<float>>(); - casting_test<SrcScalar, std::complex<double>>(); + casting_test<SrcScalar, bool>::run(); + casting_test<SrcScalar, int8_t>::run(); + casting_test<SrcScalar, uint8_t>::run(); + casting_test<SrcScalar, int16_t>::run(); + casting_test<SrcScalar, uint16_t>::run(); + casting_test<SrcScalar, int32_t>::run(); + casting_test<SrcScalar, uint32_t>::run(); +#if EIGEN_HAS_CXX11 + casting_test<SrcScalar, int64_t>::run(); + casting_test<SrcScalar, uint64_t>::run(); +#endif + casting_test<SrcScalar, half>::run(); + casting_test<SrcScalar, bfloat16>::run(); + casting_test<SrcScalar, float>::run(); + casting_test<SrcScalar, double>::run(); + casting_test<SrcScalar, std::complex<float> >::run(); + casting_test<SrcScalar, std::complex<double> >::run(); } }; @@ -238,10 +270,10 @@ struct casting_test_runner<SrcScalar, typename internal::enable_if<(NumTraits<Sr { static void run() { // Only a few casts from std::complex<T> are defined. - casting_test<SrcScalar, half>(); - casting_test<SrcScalar, bfloat16>(); - casting_test<SrcScalar, std::complex<float>>(); - casting_test<SrcScalar, std::complex<double>>(); + casting_test<SrcScalar, half>::run(); + casting_test<SrcScalar, bfloat16>::run(); + casting_test<SrcScalar, std::complex<float> >::run(); + casting_test<SrcScalar, std::complex<double> >::run(); } }; @@ -253,14 +285,16 @@ void casting_all() { casting_test_runner<uint16_t>::run(); casting_test_runner<int32_t>::run(); casting_test_runner<uint32_t>::run(); +#if EIGEN_HAS_CXX11 casting_test_runner<int64_t>::run(); casting_test_runner<uint64_t>::run(); +#endif casting_test_runner<half>::run(); casting_test_runner<bfloat16>::run(); casting_test_runner<float>::run(); casting_test_runner<double>::run(); - casting_test_runner<std::complex<float>>::run(); - casting_test_runner<std::complex<double>>::run(); + casting_test_runner<std::complex<float> >::run(); + casting_test_runner<std::complex<double> >::run(); } template <typename Scalar> |