From 9312a5bf5cd72f45558f402077b0c95683ee0fea Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 30 Jun 2021 15:53:06 -0700 Subject: Implement a generic vectorized version of Smith's algorithms for complex division. --- Eigen/src/Core/arch/AVX/Complex.h | 12 ++---- Eigen/src/Core/arch/AVX512/Complex.h | 44 +--------------------- Eigen/src/Core/arch/AltiVec/Complex.h | 10 +---- .../Core/arch/Default/GenericPacketMathFunctions.h | 20 ++++++++++ .../arch/Default/GenericPacketMathFunctionsFwd.h | 6 +++ Eigen/src/Core/arch/MSA/Complex.h | 11 ++---- Eigen/src/Core/arch/NEON/Complex.h | 27 ++----------- Eigen/src/Core/arch/SSE/Complex.h | 12 +----- Eigen/src/Core/arch/ZVector/Complex.h | 16 ++------ 9 files changed, 45 insertions(+), 113 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index ab7bd6c65..0491be992 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -167,15 +167,12 @@ template<> EIGEN_STRONG_INLINE std::complex predux_mul(const P Packet2cf(_mm256_extractf128_ps(a.v, 1)))); } + EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f) template<> EIGEN_STRONG_INLINE Packet4cf pdiv(const Packet4cf& a, const Packet4cf& b) { - Packet4cf num = pmul(a, pconj(b)); - __m256 tmp = _mm256_mul_ps(b.v, b.v); - __m256 tmp2 = _mm256_shuffle_ps(tmp,tmp,0xB1); - __m256 denom = _mm256_add_ps(tmp, tmp2); - return Packet4cf(_mm256_div_ps(num.v, denom)); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet4cf pcplxflip(const Packet4cf& x) @@ -321,10 +318,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d) template<> EIGEN_STRONG_INLINE Packet2cd pdiv(const Packet2cd& a, const Packet2cd& b) { - Packet2cd num = pmul(a, pconj(b)); - __m256d tmp = _mm256_mul_pd(b.v, b.v); - __m256d denom = _mm256_hadd_pd(tmp, tmp); - return Packet2cd(_mm256_div_pd(num.v, denom)); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip(const Packet2cd& x) diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index 49c72b3f1..c11b8d2f8 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -157,11 +157,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f) template<> EIGEN_STRONG_INLINE Packet8cf pdiv(const Packet8cf& a, const Packet8cf& b) { - Packet8cf num = pmul(a, pconj(b)); - __m512 tmp = _mm512_mul_ps(b.v, b.v); - __m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1); - __m512 denom = _mm512_add_ps(tmp, tmp2); - return Packet8cf(_mm512_div_ps(num.v, denom)); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip(const Packet8cf& x) @@ -309,47 +305,11 @@ template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cd(_mm512_extractf64x4_pd(a.v,1)))); } -template<> struct conj_helper -{ - EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const - { return padd(pmul(x,y),c); } - - EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const - { - return internal::pmul(a, pconj(b)); - } -}; - -template<> struct conj_helper -{ - EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const - { return padd(pmul(x,y),c); } - - EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const - { - return internal::pmul(pconj(a), b); - } -}; - -template<> struct conj_helper -{ - EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const - { return padd(pmul(x,y),c); } - - EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const - { - return pconj(internal::pmul(a, b)); - } -}; - EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d) template<> EIGEN_STRONG_INLINE Packet4cd pdiv(const Packet4cd& a, const Packet4cd& b) { - Packet4cd num = pmul(a, pconj(b)); - __m512d tmp = _mm512_mul_pd(b.v, b.v); - __m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp); - return Packet4cd(_mm512_div_pd(num.v, denom)); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip(const Packet4cd& x) diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h index e1711930b..058f8dd1e 100644 --- a/Eigen/src/Core/arch/AltiVec/Complex.h +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -210,10 +210,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { - // TODO optimize it for AltiVec - Packet2cf res = pmul(a, pconj(b)); - Packet4f s = pmul(b.v, b.v); - return Packet2cf(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_COMPLEX32_REV)))); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& x) @@ -375,10 +372,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) { - // TODO optimize it for AltiVec - Packet1cd res = pmul(a,pconj(b)); - Packet2d s = pmul(b.v, b.v); - return Packet1cd(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_REVERSE64)))); + return pdiv_complex(a, b); } EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index c9fbaf68b..f1e10c898 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -757,6 +757,26 @@ Packet pcos_float(const Packet& x) return psincos_float(x); } +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED Packet pdiv_complex(const Packet& x, const Packet& y) { + typedef typename unpacket_traits::as_real RealPacket; + // In the following we annotate the code for the case where the inputs + // are a pair length-2 SIMD vectors representing a single pair of complex + // numbers x = a + i*b, y = c + i*d. + const RealPacket y_abs = pabs(y.v); // |c|, |d| + const RealPacket y_abs_flip = pcplxflip(Packet(y_abs)).v; // |d|, |c| + const RealPacket y_max = pmax(y_abs, y_abs_flip); // max(|c|, |d|), max(|c|, |d|) + const RealPacket y_scaled = pdiv(y.v, y_max); // c / max(|c|, |d|), d / max(|c|, |d|) + // Compute scaled denominator. + const RealPacket y_scaled_sq = pmul(y_scaled, y_scaled); // c'**2, d'**2 + const RealPacket denom = y_scaled_sq + pcplxflip(Packet(y_scaled_sq)).v; + Packet result_scaled = pmul(x, pconj(Packet(y_scaled))); // a * c' + b * d', -a * d + b * c + // Divide elementwise by denom. + result_scaled = Packet(pdiv(result_scaled.v, denom)); + // Rescale result + return Packet(pdiv(result_scaled.v, y_max)); +} template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 177a04e93..730cc7395 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -101,6 +101,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet psqrt_complex(const Packet& a); +/** \internal \returns x / y for complex types */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pdiv_complex(const Packet& x, const Packet& y); + template struct ppolevl; diff --git a/Eigen/src/Core/arch/MSA/Complex.h b/Eigen/src/Core/arch/MSA/Complex.h index 53dacfa43..76e9f7ca0 100644 --- a/Eigen/src/Core/arch/MSA/Complex.h +++ b/Eigen/src/Core/arch/MSA/Complex.h @@ -75,16 +75,13 @@ struct Packet2cf { EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const { return Packet2cf(*this) -= b; } + EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const { + return pdiv_complex(Packet2cf(*this), b); + } EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) { - *this *= b.conjugate(); - Packet4f s = pmul(b.v, b.v); - s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); - v = pdiv(v, s); + *this = Packet2cf(*this) / b; return *this; } - EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const { - return Packet2cf(*this) /= b; - } EIGEN_STRONG_INLINE Packet2cf operator-(void) const { return Packet2cf(pnegate(v)); } diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h index f40af7f87..0f74fe8d2 100644 --- a/Eigen/src/Core/arch/NEON/Complex.h +++ b/Eigen/src/Core/arch/NEON/Complex.h @@ -347,27 +347,11 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) template<> EIGEN_STRONG_INLINE Packet1cf pdiv(const Packet1cf& a, const Packet1cf& b) { - // TODO optimize it for NEON - Packet1cf res = pmul(a, pconj(b)); - Packet2f s, rev_s; - - // this computes the norm - s = vmul_f32(b.v, b.v); - rev_s = vrev64_f32(s); - - return Packet1cf(pdiv(res.v, vadd_f32(s, rev_s))); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { - // TODO optimize it for NEON - Packet2cf res = pmul(a,pconj(b)); - Packet4f s, rev_s; - - // this computes the norm - s = vmulq_f32(b.v, b.v); - rev_s = vrev64q_f32(s); - - return Packet2cf(pdiv(res.v, vaddq_f32(s, rev_s))); + return pdiv_complex(a, b); } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& /*kernel*/) {} @@ -553,12 +537,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) { - // TODO optimize it for NEON - Packet1cd res = pmul(a,pconj(b)); - Packet2d s = pmul(b.v, b.v); - Packet2d rev_s = preverse(s); - - return Packet1cd(pdiv(res.v, padd(s,rev_s))); + return pdiv_complex(a, b); } EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h index 8fe22da46..08abd845a 100644 --- a/Eigen/src/Core/arch/SSE/Complex.h +++ b/Eigen/src/Core/arch/SSE/Complex.h @@ -174,14 +174,9 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { - // TODO optimize it for SSE3 and 4 - Packet2cf res = pmul(a, pconj(b)); - __m128 s = _mm_mul_ps(b.v,b.v); - return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,vec4f_swizzle1(s, 1, 0, 3, 2)))); + return pdiv_complex(a, b); } - - //---------- double ---------- struct Packet1cd { @@ -299,10 +294,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) { - // TODO optimize it for SSE3 and 4 - Packet1cd res = pmul(a,pconj(b)); - __m128d s = _mm_mul_pd(b.v,b.v); - return Packet1cd(_mm_div_pd(res.v, _mm_add_pd(s,_mm_shuffle_pd(s, s, 0x1)))); + return pdiv_complex(a, b); } EIGEN_STRONG_INLINE Packet1cd pcplxflip/* */(const Packet1cd& x) diff --git a/Eigen/src/Core/arch/ZVector/Complex.h b/Eigen/src/Core/arch/ZVector/Complex.h index 0b9b33d99..a81ec249b 100644 --- a/Eigen/src/Core/arch/ZVector/Complex.h +++ b/Eigen/src/Core/arch/ZVector/Complex.h @@ -169,10 +169,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) { - // TODO optimize it for AltiVec - Packet1cd res = pmul(a,pconj(b)); - Packet2d s = vec_madd(b.v, b.v, p2d_ZERO_); - return Packet1cd(pdiv(res.v, s + vec_perm(s, s, p16uc_REVERSE64))); + return pdiv_complex(a, b); } EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) @@ -308,11 +305,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { - // TODO optimize it for AltiVec - Packet2cf res; - res.cd[0] = pdiv(a.cd[0], b.cd[0]); - res.cd[1] = pdiv(a.cd[1], b.cd[1]); - return res; + return pdiv_complex(a, b); } EIGEN_STRONG_INLINE Packet2cf pcplxflip/**/(const Packet2cf& x) @@ -394,10 +387,7 @@ EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { - // TODO optimize it for AltiVec - Packet2cf res = pmul(a, pconj(b)); - Packet4f s = pmul(b.v, b.v); - return Packet2cf(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_COMPLEX32_REV)))); + return pdiv_complex(a, b); } template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& x) -- cgit v1.2.3