aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/MathFunctions.h22
-rw-r--r--Eigen/src/Core/arch/Default/Half.h14
-rw-r--r--test/basicstuff.cpp33
-rw-r--r--test/bfloat16_float.cpp16
4 files changed, 29 insertions, 56 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 07f4b9493..e9da35995 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -376,7 +376,7 @@ struct hypot_retval
* Implementation of cast *
****************************************************************************/
-template<typename OldType, typename NewType>
+template<typename OldType, typename NewType, typename EnableIf = void>
struct cast_impl
{
EIGEN_DEVICE_FUNC
@@ -386,6 +386,22 @@ struct cast_impl
}
};
+// Casting from S -> Complex<T> leads to an implicit conversion from S to T,
+// generating warnings on clang. Here we explicitly cast the real component.
+template<typename OldType, typename NewType>
+struct cast_impl<OldType, NewType,
+ typename internal::enable_if<
+ !NumTraits<OldType>::IsComplex && NumTraits<NewType>::IsComplex
+ >::type>
+{
+ EIGEN_DEVICE_FUNC
+ static inline NewType run(const OldType& x)
+ {
+ typedef typename NumTraits<NewType>::Real NewReal;
+ return static_cast<NewType>(static_cast<NewReal>(x));
+ }
+};
+
// here, for once, we're plainly returning NewType: we don't want cast to do weird things.
template<typename OldType, typename NewType>
@@ -486,7 +502,7 @@ struct rint_retval
#if defined(EIGEN_HIP_DEVICE_COMPILE)
// HIP does not seem to have a native device side implementation for the math routine "arg"
using std::arg;
- #else
+ #else
EIGEN_USING_STD(arg);
#endif
return arg(x);
@@ -967,7 +983,7 @@ template<typename T> T generic_fast_tanh_float(const T& a_x);
namespace numext {
-#if (!defined(EIGEN_GPUCC) || defined(EIGEN_CONSTEXPR_ARE_DEVICE_FUNC))
+#if (!defined(EIGEN_GPUCC) || defined(EIGEN_CONSTEXPR_ARE_DEVICE_FUNC))
template<typename T>
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE T mini(const T& x, const T& y)
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 4fdda8af8..bf408149a 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -60,20 +60,6 @@ namespace Eigen {
struct half;
-// explicit conversion operators are no available before C++11 so we first cast
-// half to RealScalar rather than to std::complex<RealScalar> directly
-#if !EIGEN_HAS_CXX11
-namespace internal {
-template <typename RealScalar>
-struct cast_impl<half, std::complex<RealScalar> > {
- EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const half &x)
- {
- return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
- }
-};
-} // namespace internal
-#endif // EIGEN_HAS_CXX11
-
namespace half_impl {
#if !defined(EIGEN_HAS_GPU_FP16)
diff --git a/test/basicstuff.cpp b/test/basicstuff.cpp
index f9044a27a..4ca607c82 100644
--- a/test/basicstuff.cpp
+++ b/test/basicstuff.cpp
@@ -195,11 +195,8 @@ template<typename MatrixType> void basicStuffComplex(const MatrixType& m)
VERIFY(!static_cast<const MatrixType&>(cm).imag().isZero());
}
-template<typename SrcScalar, typename TgtScalar, bool SrcIsHalfOrBF16 = (internal::is_same<SrcScalar, half>::value || internal::is_same<SrcScalar, bfloat16>::value)> struct casting_test;
-
-
template<typename SrcScalar, typename TgtScalar>
-struct casting_test<SrcScalar, TgtScalar, false> {
+struct casting_test {
static void run() {
Matrix<SrcScalar,4,4> m;
for (int i=0; i<m.rows(); ++i) {
@@ -210,33 +207,7 @@ struct casting_test<SrcScalar, TgtScalar, false> {
Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
for (int i=0; i<m.rows(); ++i) {
for (int j=0; j<m.cols(); ++j) {
- VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(m(i, j)));
- }
- }
- }
-};
-
-template<typename SrcScalar, typename TgtScalar>
-struct casting_test<SrcScalar, TgtScalar, true> {
- static void run() {
- casting_test<SrcScalar, TgtScalar, false>::run();
- }
-};
-
-template<typename SrcScalar, typename RealScalar>
-struct casting_test<SrcScalar, std::complex<RealScalar>, true> {
- static void run() {
- typedef std::complex<RealScalar> TgtScalar;
- Matrix<SrcScalar,4,4> m;
- for (int i=0; i<m.rows(); ++i) {
- for (int j=0; j<m.cols(); ++j) {
- m(i, j) = internal::random_without_cast_overflow<SrcScalar, TgtScalar>::value();
- }
- }
- Matrix<TgtScalar,4,4> n = m.template cast<TgtScalar>();
- for (int i=0; i<m.rows(); ++i) {
- for (int j=0; j<m.cols(); ++j) {
- VERIFY_IS_APPROX(n(i, j), static_cast<TgtScalar>(static_cast<RealScalar>(m(i, j))));
+ VERIFY_IS_APPROX(n(i, j), (internal::cast<SrcScalar,TgtScalar>(m(i, j))));
}
}
}
diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp
index 79c868e84..09df2b2f2 100644
--- a/test/bfloat16_float.cpp
+++ b/test/bfloat16_float.cpp
@@ -44,14 +44,14 @@ void test_truncate(float input, float expected_truncation, float expected_roundi
template<typename T>
void test_roundtrip() {
// Representable T round trip via bfloat16
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(-std::numeric_limits<T>::infinity())), -std::numeric_limits<T>::infinity());
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(std::numeric_limits<T>::infinity())), std::numeric_limits<T>::infinity());
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-1.0))), T(-1.0));
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.5))), T(-0.5));
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.0))), T(-0.0));
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(1.0))), T(1.0));
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.5))), T(0.5));
- VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.0))), T(0.0));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(-std::numeric_limits<T>::infinity()))), -std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(std::numeric_limits<T>::infinity()))), std::numeric_limits<T>::infinity());
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-1.0)))), T(-1.0));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.5)))), T(-0.5));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.0)))), T(-0.0));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(1.0)))), T(1.0));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.5)))), T(0.5));
+ VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.0)))), T(0.0));
}
void test_conversion()