aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/Half.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-06-25 14:31:16 -0700
committerGravatar Antonio Sánchez <cantonios@google.com>2020-06-30 18:53:55 +0000
commit9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch)
tree5348c34ac0673d09fe97aea29770e7b236e85510 /Eigen/src/Core/arch/Default/Half.h
parent145e51516fdac7b30d22c11c6878c2805fc3d724 (diff)
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex<T>`, which was missing to make it consistent with `Eigen:bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations.
Diffstat (limited to 'Eigen/src/Core/arch/Default/Half.h')
-rw-r--r--Eigen/src/Core/arch/Default/Half.h12
1 files changed, 11 insertions, 1 deletions
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index cfd0bdc06..b84cfc7db 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -86,7 +86,7 @@ struct half_base : public __half_raw {
#if (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000)
EIGEN_DEVICE_FUNC half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
#endif
- #endif
+ #endif
#endif
};
@@ -133,6 +133,11 @@ struct half : public half_impl::half_base {
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
explicit EIGEN_DEVICE_FUNC half(float f)
: half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
+ // Following the convention of numpy, converting between complex and
+ // float will lead to loss of imag value.
+ template<typename RealScalar>
+ explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
+ : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
// +0.0 and -0.0 become false, everything else becomes true.
@@ -174,6 +179,11 @@ struct half : public half_impl::half_base {
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
return static_cast<double>(half_impl::half_to_float(*this));
}
+
+ template<typename RealScalar>
+ EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
+ return std::complex<RealScalar>(static_cast<RealScalar>(*this), RealScalar(0));
+ }
};
} // end namespace Eigen