diff options
Diffstat (limited to 'Eigen/src/Core/arch')
-rw-r--r-- | Eigen/src/Core/arch/Default/BFloat16.h | 17 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/Half.h | 12 |
2 files changed, 16 insertions, 13 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index c3725d473..99ce99a27 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -65,13 +65,8 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {} // Following the convention of numpy, converting between complex and // float will lead to loss of imag value. - // Single precision complex. - typedef std::complex<float> complex64; - // Double precision complex. - typedef std::complex<double> complex128; - explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) - : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {} - explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) + template<typename RealScalar> + explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {} EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { @@ -114,11 +109,9 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)); } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const { - return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const { - return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0)); + template<typename RealScalar> + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { + return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0)); } EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const { return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this)); diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index cfd0bdc06..b84cfc7db 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -86,7 +86,7 @@ struct half_base : public __half_raw { #if (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000) EIGEN_DEVICE_FUNC half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {} #endif - #endif + #endif #endif }; @@ -133,6 +133,11 @@ struct half : public half_impl::half_base { : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {} explicit EIGEN_DEVICE_FUNC half(float f) : half_impl::half_base(half_impl::float_to_half_rtne(f)) {} + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + template<typename RealScalar> + explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c) + : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {} EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { // +0.0 and -0.0 become false, everything else becomes true. @@ -174,6 +179,11 @@ struct half : public half_impl::half_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast<double>(half_impl::half_to_float(*this)); } + + template<typename RealScalar> + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { + return std::complex<RealScalar>(static_cast<RealScalar>(*this), RealScalar(0)); + } }; } // end namespace Eigen |