aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/functors
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-11-12 13:42:24 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-11-12 13:42:24 -0800
commit77b447c24e3344e43ff64eb932d4bb35a2db01ce (patch)
tree31ef3a98a227660054435b7792c43c578ed68d2c /Eigen/src/Core/functors
parentc81bdbdadc72b96dda3c4a120bfb189df62ece18 (diff)
Add optimized version of logistic function for float. As an example, this is about 50% faster than the existing version on Haswell using AVX.
Diffstat (limited to 'Eigen/src/Core/functors')
-rw-r--r--Eigen/src/Core/functors/UnaryFunctors.h61
1 files changed, 61 insertions, 0 deletions
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index c1cc2ab3b..0c2d2cfca 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -850,6 +850,67 @@ struct functor_traits<scalar_logistic_op<T> > {
};
};
+/** \internal
+ * \brief Template specialization of the logistic function for float.
+ *
+ * Uses just a 9/10-degree rational interpolant which
+ * interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range
+ * [-18, 18], outside of which the fl(logistic(x)) = {0|1}. The shifted
+ * logistic is interpolated because it was easier to make the fit converge.
+ *
+ */
+
+template <>
+struct scalar_logistic_op<float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
+ const float one = 1.0f;
+ return one / (one + numext::exp(-x));
+ }
+
+ template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Packet packetOp(const Packet& _x) const {
+ // Clamp the inputs to the range [-18, 18] since anything outside
+ // this range is 0.0f or 1.0f in single-precision.
+ const Packet x = pmax(pmin(_x, pset1<Packet>(18.0)), pset1<Packet>(-18.0));
+
+ // The monomial coefficients of the numerator polynomial (odd).
+ const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01);
+ const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03);
+ const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05);
+ const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07);
+ const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11);
+
+ // The monomial coefficients of the denominator polynomial (even).
+ const Packet beta_0 = pset1<Packet>(9.93151921023180e-01);
+ const Packet beta_2 = pset1<Packet>(1.16817656904453e-01);
+ const Packet beta_4 = pset1<Packet>(1.70198817374094e-03);
+ const Packet beta_6 = pset1<Packet>(6.29106785017040e-06);
+ const Packet beta_8 = pset1<Packet>(5.76102136993427e-09);
+ const Packet beta_10 = pset1<Packet>(6.10247389755681e-13);
+
+ // Since the polynomials are odd/even, we need x^2.
+ const Packet x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial p.
+ Packet p = pmadd(x2, alpha_9, alpha_7);
+ p = pmadd(x2, p, alpha_5);
+ p = pmadd(x2, p, alpha_3);
+ p = pmadd(x2, p, alpha_1);
+ p = pmul(x, p);
+
+ // Evaluate the denominator polynomial p.
+ Packet q = pmadd(x2, beta_10, beta_8);
+ q = pmadd(x2, q, beta_6);
+ q = pmadd(x2, q, beta_4);
+ q = pmadd(x2, q, beta_2);
+ q = pmadd(x2, q, beta_0);
+
+ // Divide the numerator by the denominator and shift it up.
+ return pmax(pmin(padd(pdiv(p, q), pset1<Packet>(0.5)), pset1<Packet>(1.0)),
+ pset1<Packet>(0.0));
+ }
+};
} // end namespace internal