aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nathan Luehr <nluehr@nvidia.com>2021-04-19 18:05:27 -0500
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-05-11 22:02:21 +0000
commit7e6a1c129c201db4eff46f4dd68acdc7e935eaf2 (patch)
treece3662dea2b4b5329e89f7dc75f7bb1918bbff9b
parent6753f0f197e7b8a8019e82e7b144ac0281d6a7f1 (diff)
Device implementation of log for std::complex types.
-rw-r--r--Eigen/src/Core/MathFunctions.h30
-rw-r--r--Eigen/src/Core/MathFunctionsImpl.h9
2 files changed, 36 insertions, 3 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 7f82090a9..d7ac4d64d 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -2,6 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2006-2010 Benoit Jacob <jacob.benoit.1@gmail.com>
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@@ -688,6 +689,30 @@ struct expm1_retval
};
/****************************************************************************
+* Implementation of log *
+****************************************************************************/
+
+// Complex log defined in MathFunctionsImpl.h.
+template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z);
+
+template<typename Scalar>
+struct log_impl {
+ EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_USING_STD(log);
+ return static_cast<Scalar>(log(x));
+ }
+};
+
+template<typename Scalar>
+struct log_impl<std::complex<Scalar> > {
+ EIGEN_DEVICE_FUNC static inline std::complex<Scalar> run(const std::complex<Scalar>& z)
+ {
+ return complex_log(z);
+ }
+};
+
+/****************************************************************************
* Implementation of log1p *
****************************************************************************/
@@ -700,7 +725,7 @@ namespace std_fallback {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_USING_STD(log);
Scalar x1p = RealScalar(1) + x;
- Scalar log_1p = log(x1p);
+ Scalar log_1p = log_impl<Scalar>::run(x1p);
const bool is_small = numext::equal_strict(x1p, Scalar(1));
const bool is_inf = numext::equal_strict(x1p, log_1p);
return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1)));
@@ -1460,8 +1485,7 @@ T rsqrt(const T& x)
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
- EIGEN_USING_STD(log);
- return static_cast<T>(log(x));
+ return internal::log_impl<T>::run(x);
}
#if defined(SYCL_DEVICE_ONLY)
diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h
index 0d3f317bb..4eaaaa784 100644
--- a/Eigen/src/Core/MathFunctionsImpl.h
+++ b/Eigen/src/Core/MathFunctionsImpl.h
@@ -184,6 +184,15 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
: std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz );
}
+template<typename T>
+EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
+ // Computes complex log.
+ T a = numext::abs(z);
+ EIGEN_USING_STD(atan2);
+ T b = atan2(z.imag(), z.real());
+ return std::complex<T>(numext::log(a), b);
+}
+
} // end namespace internal
} // end namespace Eigen