aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX/MathFunctions.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/AVX/MathFunctions.h')
-rw-r--r--Eigen/src/Core/arch/AVX/MathFunctions.h107
1 files changed, 75 insertions, 32 deletions
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index c4bd6bd53..98d8e029f 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -10,11 +10,6 @@
#ifndef EIGEN_MATH_FUNCTIONS_AVX_H
#define EIGEN_MATH_FUNCTIONS_AVX_H
-// For some reason, this function didn't make it into the avxintirn.h
-// used by the compiler, so we'll just wrap it.
-#define _mm256_setr_m128(lo, hi) \
- _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)
-
/* The sin, cos, exp, and log functions of this file are loosely derived from
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/
@@ -23,6 +18,28 @@ namespace Eigen {
namespace internal {
+inline Packet8i pshiftleft(Packet8i v, int n)
+{
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_slli_epi32(v, n);
+#else
+ __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(v, 0), n);
+ __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(v, 1), n);
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
+#endif
+}
+
+inline Packet8f pshiftright(Packet8f v, int n)
+{
+#ifdef EIGEN_VECTORIZE_AVX2
+ return _mm256_cvtepi32_ps(_mm256_srli_epi32(_mm256_castps_si256(v), n));
+#else
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 0), n);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(v), 1), n);
+ return _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1));
+#endif
+}
+
// Sine function
// Computes sin(x) by wrapping x to the interval [-Pi/4,3*Pi/4] and
// evaluating interpolants in [-Pi/4,Pi/4] or [Pi/4,3*Pi/4]. The interpolants
@@ -54,17 +71,8 @@ psin<Packet8f>(const Packet8f& _x) {
// Make a mask for the entries that need flipping, i.e. wherever the shift
// is odd.
Packet8i shift_ints = _mm256_cvtps_epi32(shift);
- Packet8i shift_isodd =
- _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(shift_ints), _mm256_castsi256_ps(p8i_one)));
-#ifdef EIGEN_VECTORIZE_AVX2
- Packet8i sign_flip_mask = _mm256_slli_epi32(shift_isodd, 31);
-#else
- __m128i lo =
- _mm_slli_epi32(_mm256_extractf128_si256(shift_isodd, 0), 31);
- __m128i hi =
- _mm_slli_epi32(_mm256_extractf128_si256(shift_isodd, 1), 31);
- Packet8i sign_flip_mask = _mm256_setr_m128(lo, hi);
-#endif
+ Packet8i shift_isodd = _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(shift_ints), _mm256_castsi256_ps(p8i_one)));
+ Packet8i sign_flip_mask = pshiftleft(shift_isodd, 31);
// Create a mask for which interpolant to use, i.e. if z > 1, then the mask
// is set to ones for that entry.
@@ -142,15 +150,7 @@ plog<Packet8f>(const Packet8f& _x) {
// Truncate input values to the minimum positive normal.
x = pmax(x, p8f_min_norm_pos);
-// Extract the shifted exponents (No bitwise shifting in regular AVX, so
-// convert to SSE and do it there).
-#ifdef EIGEN_VECTORIZE_AVX2
- Packet8f emm0 = _mm256_cvtepi32_ps(_mm256_srli_epi32(_mm256_castps_si256(x), 23));
-#else
- __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(x), 0), 23);
- __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(_mm256_castps_si256(x), 1), 23);
- Packet8f emm0 = _mm256_cvtepi32_ps(_mm256_setr_m128(lo, hi));
-#endif
+ Packet8f emm0 = pshiftright(x,23);
Packet8f e = _mm256_sub_ps(emm0, p8f_126f);
// Set the exponents to -1, i.e. x are in the range [0.5,1).
@@ -259,18 +259,61 @@ pexp<Packet8f>(const Packet8f& _x) {
// Build emm0 = 2^m.
Packet8i emm0 = _mm256_cvttps_epi32(padd(m, p8f_127));
-#ifdef EIGEN_VECTORIZE_AVX2
- emm0 = _mm256_slli_epi32(emm0, 23);
-#else
- __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(emm0, 0), 23);
- __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(emm0, 1), 23);
- emm0 = _mm256_setr_m128(lo, hi);
-#endif
+ emm0 = pshiftleft(emm0, 23);
// Return 2^m * exp(r).
return pmax(pmul(y, _mm256_castsi256_ps(emm0)), _x);
}
+// Hyperbolic Tangent function.
+// Doesn't do anything fancy, just a 13/6-degree rational interpolant which
+// is accurate up to a couple of ulp in the range [-9, 9], outside of which the
+// fl(tanh(x)) = +/-1.
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
+ptanh<Packet8f>(const Packet8f& _x) {
+ // Clamp the inputs to the range [-9, 9] since anything outside
+ // this range is +/-1.0f in single-precision.
+ _EIGEN_DECLARE_CONST_Packet8f(plus_9, 9.0f);
+ _EIGEN_DECLARE_CONST_Packet8f(minus_9, -9.0f);
+ const Packet8f x = pmax(p8f_minus_9, pmin(p8f_plus_9, _x));
+
+ // The monomial coefficients of the numerator polynomial (odd).
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_1, 4.89352455891786e-03f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_3, 6.37261928875436e-04f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_5, 1.48572235717979e-05f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_7, 5.12229709037114e-08f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_9, -8.60467152213735e-11f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_11, 2.00018790482477e-13f);
+ _EIGEN_DECLARE_CONST_Packet8f(alpha_13, -2.76076847742355e-16f);
+
+ // The monomial coefficients of the denominator polynomial (even).
+ _EIGEN_DECLARE_CONST_Packet8f(beta_0, 4.89352518554385e-03f);
+ _EIGEN_DECLARE_CONST_Packet8f(beta_2, 2.26843463243900e-03f);
+ _EIGEN_DECLARE_CONST_Packet8f(beta_4, 1.18534705686654e-04f);
+ _EIGEN_DECLARE_CONST_Packet8f(beta_6, 1.19825839466702e-06f);
+
+ // Since the polynomials are odd/even, we need x^2.
+ const Packet8f x2 = pmul(x, x);
+
+ // Evaluate the numerator polynomial p.
+ Packet8f p = pmadd(x2, p8f_alpha_13, p8f_alpha_11);
+ p = pmadd(x2, p, p8f_alpha_9);
+ p = pmadd(x2, p, p8f_alpha_7);
+ p = pmadd(x2, p, p8f_alpha_5);
+ p = pmadd(x2, p, p8f_alpha_3);
+ p = pmadd(x2, p, p8f_alpha_1);
+ p = pmul(x, p);
+
+ // Evaluate the denominator polynomial p.
+ Packet8f q = pmadd(x2, p8f_beta_6, p8f_beta_4);
+ q = pmadd(x2, q, p8f_beta_2);
+ q = pmadd(x2, q, p8f_beta_0);
+
+ // Divide the numerator by the denominator.
+ return pdiv(p, q);
+}
+
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
pexp<Packet4d>(const Packet4d& _x) {