diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-06-25 14:31:16 -0700 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-06-30 18:53:55 +0000 |
commit | 9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch) | |
tree | 5348c34ac0673d09fe97aea29770e7b236e85510 /Eigen/src/Core/arch/Default/BFloat16.h | |
parent | 145e51516fdac7b30d22c11c6878c2805fc3d724 (diff) |
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for
`SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the
missing 1:N and 8:1.
We also add casting `Eigen::half` to/from `std::complex<T>`, which
was missing to make it consistent with `Eigen:bfloat16`, and
generalize the overload to work for any complex type.
Tests were added to `basicstuff`, `packetmath`, and
`cxx11_tensor_casts` to test all cast configurations.
Diffstat (limited to 'Eigen/src/Core/arch/Default/BFloat16.h')
-rw-r--r-- | Eigen/src/Core/arch/Default/BFloat16.h | 17 |
1 files changed, 5 insertions, 12 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index c3725d473..99ce99a27 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -65,13 +65,8 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {} // Following the convention of numpy, converting between complex and // float will lead to loss of imag value. - // Single precision complex. - typedef std::complex<float> complex64; - // Double precision complex. - typedef std::complex<double> complex128; - explicit EIGEN_DEVICE_FUNC bfloat16(const complex64& val) - : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(val.real())) {} - explicit EIGEN_DEVICE_FUNC bfloat16(const complex128& val) + template<typename RealScalar> + explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val) : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {} EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const { @@ -114,11 +109,9 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)); } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex64) const { - return complex64(bfloat16_impl::bfloat16_to_float(*this), float(0.0)); - } - EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(complex128) const { - return complex128(static_cast<double>(bfloat16_impl::bfloat16_to_float(*this)), double(0.0)); + template<typename RealScalar> + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const { + return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0)); } EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const { return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this)); |