From 77b447c24e3344e43ff64eb932d4bb35a2db01ce Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Mon, 12 Nov 2018 13:42:24 -0800 Subject: Add optimized version of logistic function for float. As an example, this is about 50% faster than the existing version on Haswell using AVX. --- Eigen/src/Core/functors/UnaryFunctors.h | 61 +++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index c1cc2ab3b..0c2d2cfca 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -850,6 +850,67 @@ struct functor_traits > { }; }; +/** \internal + * \brief Template specialization of the logistic function for float. + * + * Uses just a 9/10-degree rational interpolant which + * interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range + * [-18, 18], outside of which the fl(logistic(x)) = {0|1}. The shifted + * logistic is interpolated because it was easier to make the fit converge. + * + */ + +template <> +struct scalar_logistic_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const { + const float one = 1.0f; + return one / (one + numext::exp(-x)); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Packet packetOp(const Packet& _x) const { + // Clamp the inputs to the range [-18, 18] since anything outside + // this range is 0.0f or 1.0f in single-precision. + const Packet x = pmax(pmin(_x, pset1(18.0)), pset1(-18.0)); + + // The monomial coefficients of the numerator polynomial (odd). + const Packet alpha_1 = pset1(2.48287947061529e-01); + const Packet alpha_3 = pset1(8.51377133304701e-03); + const Packet alpha_5 = pset1(6.08574864600143e-05); + const Packet alpha_7 = pset1(1.15627324459942e-07); + const Packet alpha_9 = pset1(4.37031012579801e-11); + + // The monomial coefficients of the denominator polynomial (even). + const Packet beta_0 = pset1(9.93151921023180e-01); + const Packet beta_2 = pset1(1.16817656904453e-01); + const Packet beta_4 = pset1(1.70198817374094e-03); + const Packet beta_6 = pset1(6.29106785017040e-06); + const Packet beta_8 = pset1(5.76102136993427e-09); + const Packet beta_10 = pset1(6.10247389755681e-13); + + // Since the polynomials are odd/even, we need x^2. + const Packet x2 = pmul(x, x); + + // Evaluate the numerator polynomial p. + Packet p = pmadd(x2, alpha_9, alpha_7); + p = pmadd(x2, p, alpha_5); + p = pmadd(x2, p, alpha_3); + p = pmadd(x2, p, alpha_1); + p = pmul(x, p); + + // Evaluate the denominator polynomial p. + Packet q = pmadd(x2, beta_10, beta_8); + q = pmadd(x2, q, beta_6); + q = pmadd(x2, q, beta_4); + q = pmadd(x2, q, beta_2); + q = pmadd(x2, q, beta_0); + + // Divide the numerator by the denominator and shift it up. + return pmax(pmin(padd(pdiv(p, q), pset1(0.5)), pset1(1.0)), + pset1(0.0)); + } +}; } // end namespace internal -- cgit v1.2.3