From 125cc9a5df6074756b89ea8aaa4e9a4b44b0f7e9 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Tue, 8 Dec 2020 18:13:35 -0800 Subject: Implement vectorized complex square root. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #1905 Measured speedup for sqrt of `complex` on Skylake: SSE: ``` name old time/op new time/op delta BM_eigen_sqrt_ctype/1 49.4ns ± 0% 54.3ns ± 0% +10.01% BM_eigen_sqrt_ctype/8 332ns ± 0% 50ns ± 1% -84.97% BM_eigen_sqrt_ctype/64 2.81µs ± 1% 0.38µs ± 0% -86.49% BM_eigen_sqrt_ctype/512 23.8µs ± 0% 3.0µs ± 0% -87.32% BM_eigen_sqrt_ctype/4k 202µs ± 0% 24µs ± 2% -88.03% BM_eigen_sqrt_ctype/32k 1.63ms ± 0% 0.19ms ± 0% -88.18% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 1.5ms ± 1% -88.20% BM_eigen_sqrt_ctype/1M 52.1ms ± 0% 6.2ms ± 0% -88.18% ``` AVX2: ``` name old cpu/op new cpu/op delta BM_eigen_sqrt_ctype/1 53.6ns ± 0% 55.6ns ± 0% +3.71% BM_eigen_sqrt_ctype/8 334ns ± 0% 27ns ± 0% -91.86% BM_eigen_sqrt_ctype/64 2.79µs ± 0% 0.22µs ± 2% -92.28% BM_eigen_sqrt_ctype/512 23.8µs ± 1% 1.7µs ± 1% -92.81% BM_eigen_sqrt_ctype/4k 201µs ± 0% 14µs ± 1% -93.24% BM_eigen_sqrt_ctype/32k 1.62ms ± 0% 0.11ms ± 1% -93.29% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 0.9ms ± 1% -93.31% BM_eigen_sqrt_ctype/1M 52.0ms ± 0% 3.5ms ± 1% -93.31% ``` AVX512: ``` name old cpu/op new cpu/op delta BM_eigen_sqrt_ctype/1 53.7ns ± 0% 56.2ns ± 1% +4.75% BM_eigen_sqrt_ctype/8 334ns ± 0% 18ns ± 2% -94.63% BM_eigen_sqrt_ctype/64 2.79µs ± 0% 0.12µs ± 1% -95.54% BM_eigen_sqrt_ctype/512 23.9µs ± 1% 1.0µs ± 1% -95.89% BM_eigen_sqrt_ctype/4k 202µs ± 0% 8µs ± 1% -96.13% BM_eigen_sqrt_ctype/32k 1.63ms ± 0% 0.06ms ± 1% -96.15% BM_eigen_sqrt_ctype/256k 13.0ms ± 0% 0.5ms ± 4% -96.11% BM_eigen_sqrt_ctype/1M 52.1ms ± 0% 2.0ms ± 1% -96.13% ``` --- Eigen/src/Core/GenericPacketMath.h | 14 +++ Eigen/src/Core/arch/AVX/Complex.h | 36 ++++++- Eigen/src/Core/arch/AVX/PacketMath.h | 5 + Eigen/src/Core/arch/AVX512/Complex.h | 17 ++- Eigen/src/Core/arch/AVX512/PacketMath.h | 13 +++ .../Core/arch/Default/GenericPacketMathFunctions.h | 114 +++++++++++++++++++++ .../arch/Default/GenericPacketMathFunctionsFwd.h | 7 ++ Eigen/src/Core/arch/SSE/Complex.h | 38 ++++++- Eigen/src/Core/arch/SSE/PacketMath.h | 4 + test/packetmath.cpp | 53 +++++++++- 10 files changed, 290 insertions(+), 11 deletions(-) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 3f80e8033..e2d9af47f 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -539,6 +539,20 @@ inline void pbroadcast2(const typename unpacket_traits::type *a, template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet plset(const typename unpacket_traits::type& a) { return a; } +/** \internal \returns a packet with constant coefficients \a a, e.g.: (x, 0, x, 0), + where x is the value of all 1-bits. */ +template EIGEN_DEVICE_FUNC inline Packet +peven_mask(const Packet& /*a*/) { + typedef typename unpacket_traits::type Scalar; + const size_t n = unpacket_traits::size; + Scalar elements[n]; + for(size_t i = 0; i < n; ++i) { + memset(elements+i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar)); + } + return ploadu(elements); +} + + /** \internal copy the packet \a from to \a *to, \a to must be 16 bytes aligned */ template EIGEN_DEVICE_FUNC inline void pstore(Scalar* to, const Packet& from) { (*to) = from; } diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index 23568cae9..506ca0be5 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -38,6 +38,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -47,7 +48,18 @@ template<> struct packet_traits > : default_packet_traits }; #endif -template<> struct unpacket_traits { typedef std::complex type; enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; }; +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet2cf half; + typedef Packet8f as_real; + enum { + size=4, + alignment=Aligned32, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; template<> EIGEN_STRONG_INLINE Packet4cf padd(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet4cf psub(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); } @@ -228,6 +240,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -237,7 +250,18 @@ template<> struct packet_traits > : default_packet_traits }; #endif -template<> struct unpacket_traits { typedef std::complex type; enum {size=2, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; }; +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet1cd half; + typedef Packet4d as_real; + enum { + size=2, + alignment=Aligned32, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; template<> EIGEN_STRONG_INLINE Packet2cd padd(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet2cd psub(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); } @@ -399,6 +423,14 @@ ptranspose(PacketBlock& kernel) { kernel.packet[0].v = tmp; } +template<> EIGEN_STRONG_INLINE Packet2cd psqrt(const Packet2cd& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet4cf psqrt(const Packet4cf& a) { + return psqrt_complex(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index f5c18f63f..d0152db12 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -248,6 +248,11 @@ template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _m template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); } template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); } + +template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return Packet8f(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return Packet8i(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return Packet4d(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); } + template<> EIGEN_STRONG_INLINE Packet8f pload1(const float* from) { return _mm256_broadcast_ss(from); } template<> EIGEN_STRONG_INLINE Packet4d pload1(const double* from) { return _mm256_broadcast_sd(from); } diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index 53ee53d17..45f22f436 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -37,6 +37,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -47,6 +48,8 @@ template<> struct packet_traits > : default_packet_traits template<> struct unpacket_traits { typedef std::complex type; + typedef Packet4cf half; + typedef Packet16f as_real; enum { size = 8, alignment=unpacket_traits::alignment, @@ -54,7 +57,6 @@ template<> struct unpacket_traits { masked_load_available=false, masked_store_available=false }; - typedef Packet4cf half; }; template<> EIGEN_STRONG_INLINE Packet8cf ptrue(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); } @@ -223,6 +225,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -233,6 +236,8 @@ template<> struct packet_traits > : default_packet_traits template<> struct unpacket_traits { typedef std::complex type; + typedef Packet2cd half; + typedef Packet8d as_real; enum { size = 4, alignment = unpacket_traits::alignment, @@ -240,7 +245,6 @@ template<> struct unpacket_traits { masked_load_available=false, masked_store_available=false }; - typedef Packet2cd half; }; template<> EIGEN_STRONG_INLINE Packet4cd padd(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); } @@ -437,8 +441,15 @@ ptranspose(PacketBlock& kernel) { kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask))); // [a0 b0 c0 d0] } -} // end namespace internal +template<> EIGEN_STRONG_INLINE Packet4cd psqrt(const Packet4cd& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet8cf psqrt(const Packet8cf& a) { + return psqrt_complex(a); +} +} // end namespace internal } // end namespace Eigen #endif // EIGEN_COMPLEX_AVX512_H diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index a001fb186..6662a5fe7 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -219,6 +219,19 @@ template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } +template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) { + return Packet16f(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); +} +template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) { + return Packet16i(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); +} +template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) { + return Packet8d(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, + 0, 0, -1, -1, 0, 0, -1, -1)); +} + template <> EIGEN_STRONG_INLINE Packet16f pload1(const float* from) { return _mm512_broadcastss_ps(_mm_load_ps1(from)); diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index c6bb89b05..45cc780f1 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -673,6 +673,120 @@ Packet pcos_float(const Packet& x) return psincos_float(x); } + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psqrt_complex(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Computes the principal sqrt of the complex numbers in the input. + // + // For example, for packets containing 2 complex numbers stored in interleaved format + // a = [a0, a1] = [x0, y0, x1, y1], + // where x0 = real(a0), y0 = imag(a0) etc., this function returns + // b = [b0, b1] = [u0, v0, u1, v1], + // such that b0^2 = a0, b1^2 = a1. + // + // To derive the formula for the complex square roots, let's consider the equation for + // a single complex square root of the number x + i*y. We want to find real numbers + // u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = 0.5 * (y / u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // u = |0.5 * (y / v)| + // + // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as + // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , + + // In the following, without lack of generality, we have annotated the code, assuming + // that the input is a packet of 2 complex numbers. + // + // Step 1. Compute l = [l0, l0, l1, l1], where + // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), + // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. + + Packet a_flip = pcplxflip(a); + RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] + RealPacket a_abs_flip = pabs(a_flip.v); // [|y0|, |x0|, |y1|, |x1|] + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero. + l = pselect(a_min_zero_mask, a_max, l); + + // Step 2. Compute [rho0, *, rho1, *], where + // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) + // We don't care about the imaginary parts computed here. They will be overwritten later. + const RealPacket cst_half = pset1(RealScalar(0.5)); + Packet rho; + rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); + + // Step 3. Compute [rho0, eta0, rho1, eta1], where + // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. + // set eta = 0 of input is 0 + i0. + RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); + RealPacket real_mask = peven_mask(a.v); + Packet positive_real_result; + // Compute result for inputs with positive real part. + positive_real_result.v = pselect(real_mask, rho.v, eta); + + // Step 4. Compute solution for inputs with negative real part: + // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] + const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), RealScalar(-0.0))).v; + RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); + Packet negative_real_result; + // Notice that rho is positive, so taking it's absolute value is a noop. + negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); + + // Step 5. Select solution branch based on the sign of the real parts. + Packet negative_real_mask; + negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); + negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); + Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); + + // Step 6. Handle special cases for infinities: + // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN + // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN + // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y + // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + Packet is_inf; + is_inf.v = pcmp_eq(a_abs, cst_pos_inf); + Packet is_real_inf; + is_real_inf.v = pand(is_inf.v, real_mask); + is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); + // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. + Packet real_inf_result; + real_inf_result.v = pmul(a_abs, pset1(Scalar(RealScalar(1.0), RealScalar(0.0))).v); + real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); + // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. + Packet is_imag_inf; + is_imag_inf.v = pandnot(is_inf.v, real_mask); + is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); + Packet imag_inf_result; + imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); + + return pselect(is_imag_inf, imag_inf_result, + pselect(is_real_inf, real_inf_result,result)); +} + /* polevl (modified for Eigen) * * Evaluate polynomial diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index b0f0b78fc..491f1c927 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -82,8 +82,15 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet pcos_float(const Packet& x); +/** \internal \returns sqrt(x) for complex types */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psqrt_complex(const Packet& a); + template struct ppolevl; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h index 0d322a2a1..58cdb5dbe 100644 --- a/Eigen/src/Core/arch/SSE/Complex.h +++ b/Eigen/src/Core/arch/SSE/Complex.h @@ -40,6 +40,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -50,7 +51,18 @@ template<> struct packet_traits > : default_packet_traits }; #endif -template<> struct unpacket_traits { typedef std::complex type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; }; +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet2cf half; + typedef Packet4f as_real; + enum { + size=2, + alignment=Aligned16, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; template<> EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); } @@ -83,7 +95,6 @@ template<> EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, con } template<> EIGEN_STRONG_INLINE Packet2cf ptrue (const Packet2cf& a) { return Packet2cf(ptrue(Packet4f(a.v))); } - template<> EIGEN_STRONG_INLINE Packet2cf pand (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_and_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet2cf por (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_or_ps(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet2cf pxor (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_xor_ps(a.v,b.v)); } @@ -255,6 +266,7 @@ template<> struct packet_traits > : default_packet_traits HasMul = 1, HasDiv = 1, HasNegate = 1, + HasSqrt = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -264,7 +276,18 @@ template<> struct packet_traits > : default_packet_traits }; #endif -template<> struct unpacket_traits { typedef std::complex type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; }; +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet1cd half; + typedef Packet2d as_real; + enum { + size=1, + alignment=Aligned16, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; template<> EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); } @@ -426,8 +449,15 @@ template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, co return Packet2cf(_mm_castpd_ps(result)); } -} // end namespace internal +template<> EIGEN_STRONG_INLINE Packet1cd psqrt(const Packet1cd& a) { + return psqrt_complex(a); +} +template<> EIGEN_STRONG_INLINE Packet2cf psqrt(const Packet2cf& a) { + return psqrt_complex(a); +} + +} // end namespace internal } // end namespace Eigen #endif // EIGEN_COMPLEX_SSE_H diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index ef77ab6fa..0724378bc 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -267,6 +267,10 @@ template<> EIGEN_STRONG_INLINE Packet16b pset1(const bool& from) { template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int from) { return _mm_castsi128_ps(pset1(from)); } template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); } +template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return Packet4f(_mm_set_epi32(0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return Packet4i(_mm_set_epi32(0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return Packet2d(_mm_set_epi32(0, 0, -1, -1)); } + template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); } template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); } template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index d995e8b71..0e49d93a9 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -473,8 +473,6 @@ void packetmath() { CHECK_CWISE3_IF(true, internal::pselect, internal::pselect); } - CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); - for (int i = 0; i < size; ++i) { data1[i] = internal::random(); } @@ -486,6 +484,11 @@ void packetmath() { packetmath_boolean_mask_ops(); packetmath_pcast_ops_runner::run(); packetmath_minus_zero_add(); + + for (int i = 0; i < size; ++i) { + data1[i] = numext::abs(internal::random()); + } + CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); } // Notice that this definition works for complex types as well. @@ -899,6 +902,8 @@ void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) { template void packetmath_complex() { + typedef internal::packet_traits PacketTraits; + typedef typename Scalar::value_type RealScalar; const int PacketSize = internal::unpacket_traits::size; const int size = PacketSize * 4; @@ -917,11 +922,55 @@ void packetmath_complex() { test_conj_helper(data1, data2, ref, pval); test_conj_helper(data1, data2, ref, pval); + // Test pcplxflip. { for (int i = 0; i < PacketSize; ++i) ref[i] = Scalar(std::imag(data1[i]), std::real(data1[i])); internal::pstore(pval, internal::pcplxflip(internal::pload(data1))); VERIFY(test::areApprox(ref, pval, PacketSize) && "pcplxflip"); } + + if (PacketTraits::HasSqrt) { + for (int i = 0; i < size; ++i) { + data1[i] = Scalar(internal::random(), internal::random()); + } + CHECK_CWISE1(numext::sqrt, internal::psqrt); + + // Test misc. corner cases. + const RealScalar zero = RealScalar(0); + const RealScalar one = RealScalar(1); + const RealScalar inf = std::numeric_limits::infinity(); + const RealScalar nan = std::numeric_limits::quiet_NaN(); + data1[0] = Scalar(zero, zero); + data1[1] = Scalar(-zero, zero); + data1[2] = Scalar(one, zero); + data1[3] = Scalar(zero, one); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + data1[0] = Scalar(-one, zero); + data1[1] = Scalar(zero, -one); + data1[2] = Scalar(one, one); + data1[3] = Scalar(-one, -one); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + data1[0] = Scalar(inf, zero); + data1[1] = Scalar(zero, inf); + data1[2] = Scalar(-inf, zero); + data1[3] = Scalar(zero, -inf); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + data1[0] = Scalar(inf, inf); + data1[1] = Scalar(-inf, inf); + data1[2] = Scalar(inf, -inf); + data1[3] = Scalar(-inf, -inf); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + data1[0] = Scalar(nan, zero); + data1[1] = Scalar(zero, nan); + data1[2] = Scalar(nan, one); + data1[3] = Scalar(one, nan); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + data1[0] = Scalar(nan, nan); + data1[1] = Scalar(inf, nan); + data1[2] = Scalar(nan, inf); + data1[3] = Scalar(-inf, nan); + CHECK_CWISE1(numext::sqrt, internal::psqrt); + } } template -- cgit v1.2.3