aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-12-03 22:31:44 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-12-03 22:31:44 +0000
commit4d91519a9be061da5d300079fca17dd0b9328050 (patch)
tree5546a7f478049ce24d8f69f20ad018d6a63ec807 /Eigen/src/Core/arch/Default
parent25d8ae7465e6430bc2dc7f65800332932d3bb774 (diff)
Add log2() operator to Eigen
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h3
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h63
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h12
-rw-r--r--Eigen/src/Core/arch/Default/Half.h4
4 files changed, 71 insertions, 11 deletions
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index 351f451a3..616dcf667 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -512,6 +512,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
return bfloat16(::log10f(float(a)));
}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
+ return bfloat16(static_cast<float>(M_LOG2E) * ::logf(float(a)));
+}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
return bfloat16(::sqrtf(float(a)));
}
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 60db2e12f..c6bb89b05 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -59,16 +59,16 @@ pldexp_double(Packet a, Packet exponent)
return pmul(a, preinterpret<Packet>(plogical_shift_left<52>(ei)));
}
-// Natural logarithm
+// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
// be easily approximated by a polynomial centered on m=1 for stability.
// TODO(gonnet): Further reduce the interval allowing for lower-degree
// polynomial interpolants -> ... -> profit!
-template <typename Packet>
+template <typename Packet, bool base2>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
-Packet plog_float(const Packet _x)
+Packet plog_impl_float(const Packet _x)
{
Packet x = _x;
@@ -131,8 +131,13 @@ Packet plog_float(const Packet _x)
x = padd(x, y);
// Add the logarithm of the exponent back to the result of the interpolation.
- const Packet cst_ln2 = pset1<Packet>(M_LN2);
- x = pmadd(e, cst_ln2, x);
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(static_cast<float>(M_LOG2E));
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(static_cast<float>(M_LN2));
+ x = pmadd(e, cst_ln2, x);
+ }
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
@@ -145,8 +150,23 @@ Packet plog_float(const Packet _x)
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x)
+{
+ return plog_impl_float<Packet, /* base2 */ true>(_x);
+}
-/* Returns the base e (2.718...) logarithm of x.
+/* Returns the base e (2.718...) or base 2 logarithm of x.
* The argument is separated into its exponent and fractional parts.
* The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
* is approximated by
@@ -155,16 +175,16 @@ Packet plog_float(const Packet _x)
*
* for more detail see: http://www.netlib.org/cephes/
*/
-template <typename Packet>
+template <typename Packet, bool base2>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
-Packet plog_double(const Packet _x)
+Packet plog_impl_double(const Packet _x)
{
Packet x = _x;
const Packet cst_1 = pset1<Packet>(1.0);
const Packet cst_neg_half = pset1<Packet>(-0.5);
- // The smallest non denormalized float number.
+ // The smallest non denormalized double.
const Packet cst_min_norm_pos = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
const Packet cst_minus_inf = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
const Packet cst_pos_inf = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));
@@ -232,8 +252,13 @@ Packet plog_double(const Packet _x)
x = padd(x, y);
// Add the logarithm of the exponent back to the result of the interpolation.
- const Packet cst_ln2 = pset1<Packet>(M_LN2);
- x = pmadd(e, cst_ln2, x);
+ if (base2) {
+ const Packet cst_log2e = pset1<Packet>(M_LOG2E);
+ x = pmadd(x, cst_log2e, e);
+ } else {
+ const Packet cst_ln2 = pset1<Packet>(M_LN2);
+ x = pmadd(e, cst_ln2, x);
+ }
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
@@ -246,6 +271,22 @@ Packet plog_double(const Packet _x)
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ false>(_x);
+}
+
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x)
+{
+ return plog_impl_double<Packet, /* base2 */ true>(_x);
+}
+
/** \internal \returns log(1 + x) computed using W. Kahan's formula.
See: http://www.plunk.org/~hatch/rightway.php
*/
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
index 0e02a1b20..b0f0b78fc 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -32,12 +32,24 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_float(const Packet _x);
+/** \internal \returns log2(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_float(const Packet _x);
+
/** \internal \returns log(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_double(const Packet _x);
+/** \internal \returns log2(x) for single precision float */
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+EIGEN_UNUSED
+Packet plog2_double(const Packet _x);
+
/** \internal \returns log(1 + x) */
template<typename Packet>
Packet generic_plog1p(const Packet& x);
diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h
index 7029c500d..db85f4edf 100644
--- a/Eigen/src/Core/arch/Default/Half.h
+++ b/Eigen/src/Core/arch/Default/Half.h
@@ -622,6 +622,10 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
return half(::log10f(float(a)));
}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) {
+ return half(static_cast<float>(M_LOG2E) * ::logf(float(a)));
+}
+
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
defined(EIGEN_HIP_DEVICE_COMPILE)