aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
diff options
context:
space:
mode:
author: Rasmus Munk Larsen <rmlarsen@google.com> 2020-11-24 20:53:07 +0000
committer: Rasmus Munk Larsen <rmlarsen@google.com> 2020-11-24 20:53:07 +0000
commit: c770746d709686ef2b8b652616d9232f9b028e78 (patch)
tree: 624821fa175d8f40cc13886d7483ffd35e9da1e3 /Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
parent: 22f67b59585805fedf86759f7013b2b670f83386 (diff)
Fix Half NaN definition and test.
The `half_float` test was failing with `-mcpu=cortex-a55` (native `__fp16`) due to a bad NaN bit-pattern comparison (in the case of casting a float to `__fp16`, the signaling `NaN` is quieted). There was also an inconsistency between `numeric_limits<half>::quiet_NaN()` and `NumTraits::quiet_NaN()`. Here we correct the inconsistency and compare NaNs according to the IEEE 754 definition. Also modified the `bfloat16_float` test to match. Tested with `cortex-a53` and `cortex-a55`.
Diffstat (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h')
-rw-r--r-- Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h 56
1 files changed, 56 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 6d92d1c72..4e8a42463 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -643,6 +643,62 @@ Packet pcos_float(const Packet& x)
return psincos_float<false>(x);
}
+
+template<typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet psqrt_complex(const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename Scalar::value_type RealScalar;
+ typedef typename unpacket_traits<Packet>::real RealPacket;
+
+ // Computes the principal sqrt of each complex number in a. For clarity, the
+ // comments below spell out the steps, assuming Packet contains 2 complex numbers,
+ // e.g. a = [a0_r, a0_i, a1_r, a1_i] (real and imaginary parts interleaved).
+ // In other words, the function computes b = [b0_r, b0_i, b1_r, b1_i] such that
+ // (b0_r + i*b0_i)^2 = a0_r + i*a0_i, and
+ // (b1_r + i*b1_i)^2 = a1_r + i*a1_i .
+
+ // Step 1. Compute l = [l0, l0, l1, l1], where
+ // l0 = sqrt(a0_r^2 + a0_i^2), l1 = sqrt(a1_r^2 + a1_i^2)
+ // To avoid over- and underflow, we use the stable formula for each hypotenuse
+ // l0 = (x0 == 0 ? x0 : x0 * sqrt(1 + (y0/x0)**2)),
+ // where x0 = max(|a0_r|, |a0_i|), y0 = min(|a0_r|, |a0_i|)
+ // and similarly for l1.
+ Packet a_flip = pcplxflip(a); // [a0_i, a0_r, a1_i, a1_r]
+ Packet zero_mask;
+ zero_mask.v = pcmp_eq(a.v, pzero(a.v)); // per-lane mask: which entries of a equal zero
+ RealPacket a_abs = pabs(a.v); // [|a0_r|, |a0_i|, |a1_r|, |a1_i|]
+ RealPacket a_abs_flip = pabs(a_flip.v); // [|a0_i|, |a0_r|, |a1_i|, |a1_r|]
+ RealPacket a_max = pmax(a_abs, a_abs_flip); // [x0, x0, x1, x1]
+ RealPacket a_min = pmin(a_abs, a_abs_flip); // [y0, y0, y1, y1]
+ RealPacket r = pdiv(a_min, a_max); // [y0/x0, y0/x0, y1/x1, y1/x1]
+ RealPacket one = pset1<RealPacket>(RealScalar(1));
+ RealPacket l = pmul(a_max, psqrt(padd(one, pmul(r, r)))); // [l0, l0, l1, l1]
+ // Set l to zero if both real and imaginary parts are zero (the x0 == 0 case above).
+ l = pandnot(l, pand(zero_mask.v, pcplxflip(zero_mask).v));
+
+ // Step 2. Compute
+ // [ sqrt((l0 + a0_r)/2), sqrt((l0 - a0_r)/2),
+ // sqrt((l1 + a1_r)/2), sqrt((l1 - a1_r)/2) ]
+ Packet real_mask;
+ real_mask.v = peven_mask(real_mask.v); // NOTE(review): argument is not yet assigned; presumably only selects the overload - confirm
+ Packet a_real = pand(a, real_mask); // [a0_r, 0, a1_r, 0]
+ l = padd(l, a_real.v); // [l0 + a0_r, l0, l1 + a1_r, l1]
+ l = psub(l, pcplxflip(a_real).v); // [l0 + a0_r, l0 - a0_r, l1 + a1_r, l1 - a1_r]
+ l = psqrt(pmul(l, pset1<RealPacket>(RealScalar(0.5))));
+ // If imag(a) is zero, we mask out the imaginary part of the result, which should be zero.
+ l = pandnot(l, pandnot(zero_mask.v, real_mask.v));
+
+ // Step 3. Apply the sign of the imaginary parts of a to get the final result:
+ // b = [ sqrt((l0 + a0_r)/2), sign(a0_i)*sqrt((l0 - a0_r)/2),
+ // sqrt((l1 + a1_r)/2), sign(a1_i)*sqrt((l1 - a1_r)/2) ]
+ RealPacket imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), RealScalar(-0.0))).v;
+ RealPacket imag_signs = pand<RealPacket>(a.v, imag_sign_mask);
+ Packet result = Packet(pxor<RealPacket>(l, imag_signs));
+ return result;
+}
+
/* polevl (modified for Eigen)
*
* Evaluate polynomial