diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-12-03 22:31:44 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-12-03 22:31:44 +0000 |
commit | 4d91519a9be061da5d300079fca17dd0b9328050 (patch) | |
tree | 5546a7f478049ce24d8f69f20ad018d6a63ec807 /Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | |
parent | 25d8ae7465e6430bc2dc7f65800332932d3bb774 (diff) |
Add log2() operator to Eigen
Diffstat (limited to 'Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h')
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 63 |
1 files changed, 52 insertions, 11 deletions
```diff
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 60db2e12f..c6bb89b05 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -59,16 +59,16 @@ pldexp_double(Packet a, Packet exponent)
   return pmul(a, preinterpret<Packet>(plogical_shift_left<52>(ei)));
 }
 
-// Natural logarithm
+// Natural or base 2 logarithm.
 // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
 // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
 // be easily approximated by a polynomial centered on m=1 for stability.
 // TODO(gonnet): Further reduce the interval allowing for lower-degree
 // polynomial interpolants -> ... -> profit!
-template <typename Packet>
+template <typename Packet, bool base2>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 EIGEN_UNUSED
-Packet plog_float(const Packet _x)
+Packet plog_impl_float(const Packet _x)
 {
   Packet x = _x;
 
@@ -131,8 +131,13 @@ Packet plog_float(const Packet _x)
   x = padd(x, y);
 
   // Add the logarithm of the exponent back to the result of the interpolation.
-  const Packet cst_ln2 = pset1<Packet>(M_LN2);
-  x = pmadd(e, cst_ln2, x);
+  if (base2) {
+    const Packet cst_log2e = pset1<Packet>(static_cast<float>(M_LOG2E));
+    x = pmadd(x, cst_log2e, e);
+  } else {
+    const Packet cst_ln2 = pset1<Packet>(static_cast<float>(M_LN2));
+    x = pmadd(e, cst_ln2, x);
+  }
 
   Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
   Packet iszero_mask = pcmp_eq(_x,pzero(_x));
@@ -145,8 +150,23 @@ Packet plog_float(const Packet _x)
                  por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
 }
 
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x)
+{
+  return plog_impl_float<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x)
+{
+  return plog_impl_float<Packet, /* base2 */ true>(_x);
+}
 
-/* Returns the base e (2.718...) logarithm of x.
+/* Returns the base e (2.718...) or base 2 logarithm of x.
  * The argument is separated into its exponent and fractional parts.
  * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
  * is approximated by
@@ -155,16 +175,16 @@ Packet plog_float(const Packet _x)
  *
  * for more detail see: http://www.netlib.org/cephes/
  */
-template <typename Packet>
+template <typename Packet, bool base2>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 EIGEN_UNUSED
-Packet plog_double(const Packet _x)
+Packet plog_impl_double(const Packet _x)
 {
   Packet x = _x;
 
   const Packet cst_1 = pset1<Packet>(1.0);
   const Packet cst_neg_half = pset1<Packet>(-0.5);
-  // The smallest non denormalized float number.
+  // The smallest non denormalized double.
   const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
   const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
   const Packet cst_pos_inf = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));
@@ -232,8 +252,13 @@ Packet plog_double(const Packet _x)
   x = padd(x, y);
 
   // Add the logarithm of the exponent back to the result of the interpolation.
-  const Packet cst_ln2 = pset1<Packet>(M_LN2);
-  x = pmadd(e, cst_ln2, x);
+  if (base2) {
+    const Packet cst_log2e = pset1<Packet>(M_LOG2E);
+    x = pmadd(x, cst_log2e, e);
+  } else {
+    const Packet cst_ln2 = pset1<Packet>(M_LN2);
+    x = pmadd(e, cst_ln2, x);
+  }
 
   Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
   Packet iszero_mask = pcmp_eq(_x,pzero(_x));
@@ -246,6 +271,22 @@ Packet plog_double(const Packet _x)
                  por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
 }
 
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x)
+{
+  return plog_impl_double<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x)
+{
+  return plog_impl_double<Packet, /* base2 */ true>(_x);
+}
+
 /** \internal \returns log(1 + x) computed using W. Kahan's formula.
     See: http://www.plunk.org/~hatch/rightway.php
  */
```