diff options
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 30 | ||||
-rw-r--r-- | Eigen/src/Core/MathFunctionsImpl.h | 9 |
2 files changed, 36 insertions, 3 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 7f82090a9..d7ac4d64d 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2006-2010 Benoit Jacob <jacob.benoit.1@gmail.com> +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -688,6 +689,30 @@ struct expm1_retval }; /**************************************************************************** +* Implementation of log * +****************************************************************************/ + +// Complex log defined in MathFunctionsImpl.h. +template<typename T> EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z); + +template<typename Scalar> +struct log_impl { + EIGEN_DEVICE_FUNC static inline Scalar run(const Scalar& x) + { + EIGEN_USING_STD(log); + return static_cast<Scalar>(log(x)); + } +}; + +template<typename Scalar> +struct log_impl<std::complex<Scalar> > { + EIGEN_DEVICE_FUNC static inline std::complex<Scalar> run(const std::complex<Scalar>& z) + { + return complex_log(z); + } +}; + +/**************************************************************************** * Implementation of log1p * ****************************************************************************/ @@ -700,7 +725,7 @@ namespace std_fallback { typedef typename NumTraits<Scalar>::Real RealScalar; EIGEN_USING_STD(log); Scalar x1p = RealScalar(1) + x; - Scalar log_1p = log(x1p); + Scalar log_1p = log_impl<Scalar>::run(x1p); const bool is_small = numext::equal_strict(x1p, Scalar(1)); const bool is_inf = numext::equal_strict(x1p, log_1p); return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1))); @@ -1460,8 +1485,7 @@ T rsqrt(const T& x) template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { - EIGEN_USING_STD(log); - return static_cast<T>(log(x)); + return internal::log_impl<T>::run(x); } #if defined(SYCL_DEVICE_ONLY) diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 0d3f317bb..4eaaaa784 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -184,6 +184,15 @@ EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) { : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz ); } +template<typename T> +EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) { + // Computes complex log. + T a = numext::abs(z); + EIGEN_USING_STD(atan2); + T b = atan2(z.imag(), z.real()); + return std::complex<T>(numext::log(a), b); +} + } // end namespace internal } // end namespace Eigen |