aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-08-12 13:53:28 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-08-12 13:53:28 -0700
commita3298b22ecd19a80a2fc03df3d463fdb04907c87 (patch)
treed6d9788ebd3a55539649240aa59a22479160a20a /Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
parentd55d392e7b1136655b4223bea8e99cb2fe0a8afd (diff)
Implement vectorized versions of log1p and expm1 in Eigen using Kahan's formulas, and change the scalar implementations to properly handle infinite arguments.
Depending on instruction set, significant speedups are observed for the vectorized path: log1p wall time is reduced 60-93% (2.5x - 15x speedup) expm1 wall time is reduced 0-85% (1x - 7x speedup) The scalar path is slower by 20-30% due to the extra branch needed to handle +infinity correctly. Full benchmarks measured on Intel(R) Xeon(R) Gold 6154 here: https://bitbucket.org/snippets/rmlarsen/MXBkpM
Diffstat (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h46
1 files changed, 46 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 43e827638..640aae05a 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -126,6 +126,52 @@ Packet plog_float(const Packet _x)
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
+/** \internal \returns log(1 + x) computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_plog1p(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ Packet xp1 = padd(x, one);
+ Packet small_mask = pcmp_eq(xp1, one);
+ Packet log1 = plog(xp1);
+ // Add a check to handle x == +inf.
+ Packet pos_inf_mask = pcmp_eq(x, log1);
+ Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
+ return pselect(por(small_mask, pos_inf_mask), x, log_large);
+}
+
+/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
+ See: http://www.plunk.org/~hatch/rightway.php
+ */
+template<typename Packet>
+Packet generic_expm1(const Packet& x)
+{
+ typedef typename unpacket_traits<Packet>::type ScalarType;
+ const Packet one = pset1<Packet>(ScalarType(1));
+ const Packet neg_one = pset1<Packet>(ScalarType(-1));
+ Packet u = pexp(x);
+ Packet one_mask = pcmp_eq(u, one);
+ Packet u_minus_one = psub(u, one);
+ Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one);
+ Packet logu = plog(u);
+ // The following comparison is to catch the case where
+ // exp(x) = +inf. It is written in this way to avoid having
+ // to form the constant +inf, which depends on the packet
+ // type.
+ Packet pos_inf_mask = pcmp_eq(logu, u);
+ Packet expm1 = pmul(u_minus_one, pdiv(x, logu));
+ expm1 = pselect(pos_inf_mask, u, expm1);
+ return pselect(one_mask,
+ x,
+ pselect(neg_one_mask,
+ neg_one,
+ expm1));
+}
+
+
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).